# spad/pipeline_spad.py
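"""Multi-view diffusion pipeline for SPAD.

Wraps a SPAD UNet, a VAE, a CLIP text encoder/tokenizer, and a scheduler in a
diffusers-style pipeline that denoises several camera views of an object
jointly, conditioned on text, relative camera poses, and (placeholder)
epipolar/Plucker features.
"""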
import math

import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from einops import rearrange
from tqdm import tqdm

from .geometry import get_batch_from_spherical

class SPADPipeline(DiffusionPipeline):
def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
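        """Register the denoising components and build the camera-pose projection head."""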
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler
)
self.cfg_conds = ["txt", "cam", "epi", "plucker"]
self.cfg_scales = [7.5, 1.0, 1.0, 1.0] # Default scales, adjust as needed
self.use_abs_extrinsics = False
self.use_intrinsic = False
        # project relative camera parameters ([elevation, sin(azimuth),
        # cos(azimuth), radius], optionally with intrinsics) into the UNet's
        # 1280-d conditioning space
        self.cc_projection = nn.Sequential(
            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        # zero-init the last layer so camera conditioning starts as a no-op;
        # the module is moved to the active device at call time
        nn.init.zeros_(self.cc_projection[-1].weight)
        nn.init.zeros_(self.cc_projection[-1].bias)

def generate_camera_batch(self, elevations, azimuths, use_abs=False):
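        """Build a batch of pairwise relative camera conditionings.

        For every ordered pair of views (i, j) the relative pose is encoded as
        [d_elevation, sin(d_azimuth), cos(d_azimuth), d_radius] at a fixed
        radius of 3.5; with `use_abs=True` the absolute pose of view i is
        encoded instead.
        """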
batch = get_batch_from_spherical(elevations, azimuths)
abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]
debug_cams = [[] for _ in range(len(azimuths))]
for i, icam in enumerate(abs_cams):
for j, jcam in enumerate(abs_cams):
if use_abs:
dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
else:
dcam = icam - jcam
dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
debug_cams[i].append(dcam)
batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])
        # add intrinsics to the batch (FOV of ~0.7028 rad, about 40.3 degrees)
        focal = 1 / np.tan(0.702769935131073 / 2)
        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
        # flatten to [fx, fy, cx, cy] per view (cx = cy = 0 for this centered pinhole)
        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]
return batch
def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
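        """Render a centered Gaussian blob as a dark-on-white image in [-1, 1].

        Used as a neutral image initialization for the camera batch.
        """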
X = np.linspace(-1, 1, blob_width)[None, :]
Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_var = 1 / sigma ** 2
        gaussian_blob = np.exp(-0.5 * (X**2) * inv_var) * np.exp(-0.5 * (Y**2) * inv_var)
        # invert so the blob is dark on a white background, then rescale to [-1, 1]
        if gaussian_blob.max() > 0:
            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
        gaussian_blob = 255.0 - gaussian_blob
        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)  # replicate to 3 channels
        gaussian_blob = torch.from_numpy(gaussian_blob).float()
        return gaussian_blob

@torch.no_grad()
def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
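        """Run text-conditioned multi-view generation.

        Samples `n_views` views per object (default: elevation 45 degrees at
        azimuths 0/90/180/270) with classifier-free guidance and returns the
        decoded images as channel-last numpy arrays.
        """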
batch_size = len(prompt) if isinstance(prompt, list) else 1
device = self.device
# generate camera batch
if elevations is None or azimuths is None:
elevations = [45] * 4
azimuths = [0, 90, 180, 270]
n_views = len(elevations)
camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
# prepare gaussian blob initialization
blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
# encode text
        text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt").input_ids.to(device)
        text_embeddings = self.text_encoder(text_input_ids)[0]
# prepare unconditional embeddings for classifier-free guidance
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
        # encode camera data (move the projection head to the active device first)
        self.cc_projection.to(device)
        camera_embeddings = self.cc_projection(camera_batch["cam"])
        # prepare latents: the UNet expects [n_objects, n_views, 4, 64, 64]
        latent_height, latent_width = 64, 64
        n_objects = 2  # number of objects denoised jointly
        latents = self.prepare_latents(
            n_objects,
            self.unet.in_channels,
            n_views,
            latent_height,
            latent_width,
            self.unet.dtype,
            device,
            generator=None,
        )
        # epipolar constraint masks (placeholder: all-ones disables the constraint;
        # replace with actual implementation later - MIGHT AFFECT PERFORMANCE;
        # currently unused by the denoising loop)
        epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
        # plucker embeddings at the UNet's expected 32x32 resolution (placeholder,
        # replace with actual implementation - MIGHT AFFECT PERFORMANCE)
        plucker_embeds = torch.ones(n_objects, n_views, 6, 32, 32, device=device)
        # set up scheduler
        self.scheduler.set_timesteps(num_inference_steps)
        # expand conditional and unconditional text embeddings to
        # [n_objects, n_views, max_seq_len, 768]
        text_embeddings = text_embeddings.repeat(n_objects, 1, 1).unsqueeze(1).repeat(1, n_views, 1, 1)
        uncond_embeddings = uncond_embeddings.repeat(n_objects, 1, 1).unsqueeze(1).repeat(1, n_views, 1, 1)
        # expand the pairwise relative-pose embeddings per object
        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
        # denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # per-view timesteps, shape [n_objects, n_views]
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
            # context is [text embeddings, camera embeddings, plucker embeddings]
            cond_context = [text_embeddings, camera_embeddings, plucker_embeds]
            uncond_context = [uncond_embeddings, camera_embeddings, plucker_embeds]
            # classifier-free guidance: predict the noise residual for the
            # conditional and unconditional branches, then combine
            noise_pred_text = self.unet(latents, timesteps=timesteps, context=cond_context)
            noise_pred_uncond = self.unet(latents, timesteps=timesteps, context=uncond_context)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        # decode latents: fold the view dimension into the batch so the VAE sees
        # standard [N, 4, 64, 64] inputs, then restore it
        latents_flat = rearrange(latents, "b v c h w -> (b v) c h w")
        images = self.vae.decode(latents_flat / self.vae.config.scaling_factor, return_dict=False)[0]
        images = rearrange(images, "(b v) c h w -> b v c h w", v=n_views)
        # post-process images to [0, 1], channel-last numpy arrays
        images = (images / 2 + 0.5).clamp(0, 1)
        if images.dim() == 5:
            images_output = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()
        elif images.dim() == 4:
            images_output = images.cpu().permute(0, 2, 3, 1).float().numpy()
        else:
            raise ValueError(f"Unexpected image dimensions: {images.shape}")
        return {"images": images_output, "nsfw_content_detected": [[False] * n_views for _ in range(n_objects)]}
def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
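        """Sample initial Gaussian noise of shape [batch, views, channels, height, width]."""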
        shape = (batch_size, num_views, num_channels, height, width)
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        # scale by the scheduler's initial noise sigma (a no-op for DDIM-style schedulers)
        latents = latents * self.scheduler.init_noise_sigma
        return latents
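
# Example usage (illustrative sketch: the SPAD UNet must be constructed
# separately, and the VAE/CLIP checkpoints below are common defaults, not
# identifiers verified for this repo):
#
#   from diffusers import AutoencoderKL, DDIMScheduler
#   from transformers import CLIPTextModel, CLIPTokenizer
#
#   pipe = SPADPipeline(
#       unet=unet,  # a SPAD multi-view UNet
#       vae=AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema"),
#       text_encoder=CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14"),
#       tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
#       scheduler=DDIMScheduler(num_train_timesteps=1000),
#   )
#   out = pipe(["a ceramic teapot"], num_inference_steps=50)
#   views = out["images"]  # numpy array of the generated views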