import math

import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from tqdm import tqdm

from .geometry import get_batch_from_spherical
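
# SPAD: a multi-view text-to-image diffusion pipeline. The unet is assumed to
# accept per-view timesteps and a list-valued `context` of (text embeddings,
# camera embeddings, plucker rays); everything else follows standard
# diffusers conventions.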
class SPADPipeline(DiffusionPipeline):
    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
        )
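
        # Conditioning branches ("txt", "cam", "epi", "plucker") and their
        # guidance scales; only the text scale is applied in __call__ below.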
        self.cfg_conds = ["txt", "cam", "epi", "plucker"]
        self.cfg_scales = [7.5, 1.0, 1.0, 1.0]
        self.use_abs_extrinsics = False
        self.use_intrinsic = False
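
        # Project the 4-dim camera encoding (8-dim if intrinsics are appended)
        # into the unet's 1280-dim conditioning space; the last layer is
        # zero-initialized so camera conditioning starts as a no-op.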
        self.cc_projection = nn.Sequential(
            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        # Note: the module is moved to the execution device in __call__.
        nn.init.zeros_(self.cc_projection[-1].weight)
        nn.init.zeros_(self.cc_projection[-1].bias)
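
    # Build per-view camera conditioning: pairwise relative (or absolute)
    # spherical poses, plus flattened pinhole intrinsics.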
    def generate_camera_batch(self, elevations, azimuths, use_abs=False):
        batch = get_batch_from_spherical(elevations, azimuths)

        # Absolute spherical pose per view: (elevation, azimuth, radius).
        abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]

        debug_cams = [[] for _ in range(len(azimuths))]
        for i, icam in enumerate(abs_cams):
            for j, jcam in enumerate(abs_cams):
                if use_abs:
                    dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
                else:
                    # Relative pose between views i and j, azimuth encoded as sin/cos.
                    dcam = icam - jcam
                    dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
                debug_cams[i].append(dcam)

        batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])

        # Pinhole intrinsics for a fixed field of view (~0.7 rad), flattened
        # to (fx, fy, cx, cy) per view.
        focal = 1 / np.tan(0.702769935131073 / 2)
        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]

        return batch
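
    # A soft dark-on-white Gaussian blob mapped to [-1, 1], used as a neutral
    # placeholder input image for every view.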
    def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
        X = np.linspace(-1, 1, blob_width)[None, :]
        Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_dev = 1 / sigma**2
        gaussian_blob = np.exp(-0.5 * (X**2) * inv_dev) * np.exp(-0.5 * (Y**2) * inv_dev)
        # Rescale to [0, 255], then invert so the blob is dark on white.
        if gaussian_blob.max() > 0:
            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
        gaussian_blob = 255.0 - gaussian_blob

        # Map to [-1, 1] and replicate to three channels (H, W, 3).
        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)
        gaussian_blob = torch.from_numpy(gaussian_blob)

        return gaussian_blob
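
    # End-to-end sampling: encode the prompt, build camera conditioning,
    # denoise multi-view latents with classifier-free guidance, and decode.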
    @torch.no_grad()
    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5,
                 num_images_per_prompt=1, elevations=None, azimuths=None,
                 blob_sigma=0.5, **kwargs):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        device = self.device
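
        # Default camera rig: four views at 45 degrees elevation, 90 degrees apart.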
        if elevations is None or azimuths is None:
            elevations = [45] * 4
            azimuths = [0, 90, 180, 270]

        n_views = len(elevations)
        camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
        camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
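
        # Every view starts from the same neutral Gaussian-blob image.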
        blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
        camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = self.text_encoder(text_input_ids)[0]

        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

        # cc_projection is a plain attribute rather than a registered module,
        # so move it to the execution device before projecting.
        camera_embeddings = self.cc_projection.to(device)(camera_batch["cam"])
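
        # Sample the initial multi-view latents at the VAE's 8x-downsampled
        # resolution.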
        latent_height = latent_width = self.vae.config.sample_size // 8
        latents = self.prepare_latents(
            batch_size,
            self.unet.in_channels,
            n_views,
            latent_height,
            latent_width,
            self.unet.dtype,
            device,
            generator=None,
        )

        # Placeholders for the epipolar and Plucker conditioning branches: an
        # all-ones mask (no pixel pair is masked out) and zero ray embeddings.
        # Neither is consumed by the unet call below.
        epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
        plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)

        # One generated object per prompt.
        n_objects = batch_size

        self.scheduler.set_timesteps(num_inference_steps)

        # Broadcast per-prompt embeddings across views: (B, L, D) -> (B, V, L, D).
        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
        uncond_embeddings = uncond_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
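
        # Denoising loop: each step runs the unconditional and text-conditioned
        # branches in a single doubled batch, then mixes the two noise
        # predictions with the guidance scale.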
        for t in tqdm(self.scheduler.timesteps):
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)

            context = [
                torch.cat([uncond_embeddings, text_embeddings]),
                torch.cat([camera_embeddings, camera_embeddings]),
                # Placeholder plucker signal at the unet's 32x32 feature resolution.
                torch.ones(2 * n_objects, n_views, 6, 32, 32, device=device),
            ]

            noise_pred = self.unet(
                torch.cat([latents, latents]),
                timesteps=torch.cat([timesteps, timesteps]),
                context=context,
            )

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
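
        # Decode the first view of each object back to pixel space.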
        latents_reshaped = latents[:, 0]
        images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]
        images = (images / 2 + 0.5).clamp(0, 1)

        # Channels-last numpy output, handling both (B, C, H, W) and
        # (B, V, C, H, W) decodes.
        if images.dim() == 5:
            images_output = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()
        elif images.dim() == 4:
            images_output = images.cpu().permute(0, 2, 3, 1).float().numpy()
        else:
            raise ValueError(f"Unexpected image dimensions: {images.shape}")

        return {"images": images_output, "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)]}

    def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
        shape = (batch_size, num_views, num_channels, height, width)
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents
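

# Example usage (a sketch; assumes SPAD-compatible components are available):
#
#   pipe = SPADPipeline(unet, vae, text_encoder, tokenizer, scheduler).to("cuda")
#   out = pipe(["a ceramic teapot"], num_inference_steps=50,
#              elevations=[45] * 4, azimuths=[0, 90, 180, 270])
#   images = out["images"]  # channels-last arrays in [0, 1]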