import os
import math
from itertools import chain

import numpy as np
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from einops import rearrange, repeat
from tqdm import tqdm

from .geometry import get_batch_from_spherical


class SPADPipeline(DiffusionPipeline):
    """Multi-view text-to-image pipeline that conditions a SPAD-style UNet on
    relative camera poses projected through a small MLP (`cc_projection`)."""

    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
        )

        self.cfg_conds = ["txt", "cam", "epi", "plucker"]
        self.cfg_scales = [7.5, 1.0, 1.0, 1.0]  # Default scales, adjust as needed
        self.use_abs_extrinsics = False
        self.use_intrinsic = False

        # Projects the 4-dim (or 8-dim, with intrinsics) camera vector to the
        # UNet conditioning width; the last layer is zero-initialized so the
        # camera branch starts out as a no-op.
        self.cc_projection = nn.Sequential(
            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        nn.init.zeros_(self.cc_projection[-1].weight)
        nn.init.zeros_(self.cc_projection[-1].bias)

    def generate_camera_batch(self, elevations, azimuths, use_abs=False):
        """Build per-view camera conditioning from spherical coordinates."""
        batch = get_batch_from_spherical(elevations, azimuths)

        abs_cams = [
            torch.tensor([theta, azimuth, 3.5])
            for theta, azimuth in zip(elevations, azimuths)
        ]

        # For every (i, j) view pair, encode either the absolute pose of view i
        # or its pose relative to view j as [elev, sin(azim), cos(azim), radius].
        debug_cams = [[] for _ in range(len(azimuths))]
        for i, icam in enumerate(abs_cams):
            for j, jcam in enumerate(abs_cams):
                if use_abs:
                    dcam = torch.tensor(
                        [icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]]
                    )
                else:
                    dcam = icam - jcam
                    dcam = torch.tensor(
                        [
                            dcam[0].item(),
                            math.sin(dcam[1].item()),
                            math.cos(dcam[1].item()),
                            dcam[2].item(),
                        ]
                    )
                debug_cams[i].append(dcam)
        batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])

        # Add intrinsics to the batch (shared pinhole camera, flattened to [fx, fy, cx, cy])
        focal = 1 / np.tan(0.702769935131073 / 2)
        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
        intrinsics = (
            torch.from_numpy(intrinsics)
            .unsqueeze(0)
            .float()
            .repeat(batch["cam"].shape[0], 1, 1)
        )
        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]

        return batch

    def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
        """Return an inverted Gaussian blob in [-1, 1], shape (H, W, 3)."""
        X = np.linspace(-1, 1, blob_width)[None, :]
        Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_dev = 1 / sigma**2
        gaussian_blob = np.exp(-0.5 * (X**2) * inv_dev) * np.exp(-0.5 * (Y**2) * inv_dev)
        if gaussian_blob.max() > 0:
            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
        gaussian_blob = 255.0 - gaussian_blob

        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)
        gaussian_blob = torch.from_numpy(gaussian_blob).float()

        return gaussian_blob

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        num_images_per_prompt=1,
        elevations=None,
        azimuths=None,
        blob_sigma=0.5,
        **kwargs,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        device = self.device

        # Generate camera batch; default to four views on a 45-degree elevation ring
        if elevations is None or azimuths is None:
            elevations = [45] * 4
            azimuths = [0, 90, 180, 270]
        n_views = len(elevations)
        camera_batch = self.generate_camera_batch(
            elevations, azimuths, use_abs=self.use_abs_extrinsics
        )
        camera_batch = {
            k: v[None].repeat_interleave(batch_size, dim=0).to(device)
            for k, v in camera_batch.items()
        }

        # Prepare gaussian blob initialization image for each view
        blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
        camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)

        # Encode text
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = self.text_encoder(text_input_ids)[0]

        # Prepare unconditional embeddings for classifier-free guidance
        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

        # Encode camera data
        camera_embeddings = self.cc_projection(camera_batch["cam"])

        # Prepare latents
        latent_height = latent_width = self.vae.config.sample_size // 8
        latents = self.prepare_latents(
            batch_size,
            self.unet.in_channels,
            n_views,
            latent_height,
            latent_width,
            self.unet.dtype,
            device,
            generator=None,
        )

        # Prepare epi_constraint_masks (placeholder, replace with actual implementation)
        epi_constraint_masks = torch.ones(
            batch_size, n_views, latent_height, latent_width,
            n_views, latent_height, latent_width,
            dtype=torch.bool, device=device,
        )

        # Prepare plucker embeddings (placeholder, replace with actual implementation)
        plucker_embeds = torch.zeros(
            batch_size, n_views, 6, latent_height, latent_width, device=device
        )

        # EXPERIMENTAL: override the prepared latents with a hard-coded shape.
        # Two latent sets are stacked so that noise_pred.chunk(2) below can be
        # split into an "unconditional" and a "conditional" half; note that
        # uncond_embeddings is not yet wired into `context`, so this guidance
        # split is still a placeholder.
        n_objects = 2
        latents = torch.randn(
            n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype
        )

        # Set up scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Broadcast the conditioning across objects and views (done once, outside the loop):
        # text_embeddings   -> [n_objects * batch_size, n_views, max_seq_len, dim]
        # camera_embeddings -> [n_objects * batch_size, n_views, n_views, 1280]
        text_embeddings = text_embeddings.repeat(n_objects, 1, 1)
        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)

        # Denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # One timestep per (object, view)
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)

            # Conditioning context consumed by the SPAD UNet
            context = [
                text_embeddings,    # [n_objects * batch_size, n_views, max_seq_len, dim]
                camera_embeddings,  # [n_objects * batch_size, n_views, n_views, 1280]
            ]

            # Predict noise residual
            noise_pred = self.unet(
                latents,
                timesteps=timesteps,
                context=context,
            )

            # Perform classifier-free guidance over the two stacked latent sets
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # EXPERIMENTAL: reduce the latent channels from 10 to 4 and flatten the
        # (object, view) dimensions so the VAE can decode each view separately.
        latents = latents[:, :, :4, :, :]
        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])

        # Decode latents
        images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # Post-process images
        images = (images / 2 + 0.5).clamp(0, 1)
        if images.dim() == 5:
            images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # [B, V, H, W, C]
        elif images.dim() == 4:
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # [B, H, W, C]
        else:
            raise ValueError(f"Unexpected image dimensions: {images.shape}")

        return {
            "images": images,
            "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)],
        }

    def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
        shape = (batch_size, num_views, num_channels, height, width)
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents
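

# --- Usage sketch (not part of the pipeline) ---------------------------------
# A minimal example of how this pipeline is meant to be driven, assuming the
# five modules (a SPAD-style multi-view UNet, a VAE, a CLIP text encoder and
# tokenizer, and a diffusers noise scheduler) have already been loaded
# elsewhere; the component-loading step is omitted here and any checkpoint
# names would be placeholders, not part of this repository.
#
#   pipe = SPADPipeline(
#       unet=unet,
#       vae=vae,
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       scheduler=scheduler,
#   ).to("cuda")
#
#   out = pipe(
#       ["a ceramic mug"],
#       num_inference_steps=50,
#       guidance_scale=7.5,
#       elevations=[45, 45, 45, 45],
#       azimuths=[0, 90, 180, 270],
#   )
#   views = out["images"]  # numpy array holding the generated views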