jadechoghari committed
Commit ef48aca
1 Parent(s): a1c4d22

Update pipeline_spad.py

Files changed (1): pipeline_spad.py (+28 -37)
pipeline_spad.py CHANGED
@@ -12,7 +12,7 @@ from .geometry import get_batch_from_spherical
 class SPADPipeline(DiffusionPipeline):
     def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
         super().__init__()
-
+
         self.register_modules(
             unet=unet,
             vae=vae,
@@ -20,26 +20,26 @@ class SPADPipeline(DiffusionPipeline):
             tokenizer=tokenizer,
             scheduler=scheduler
         )
-
+
         self.cfg_conds = ["txt", "cam", "epi", "plucker"]
         self.cfg_scales = [7.5, 1.0, 1.0, 1.0]  # Default scales, adjust as needed
         self.use_abs_extrinsics = False
         self.use_intrinsic = False
-
+
         self.cc_projection = nn.Sequential(
             nn.Linear(4 if not self.use_intrinsic else 8, 1280),
             nn.SiLU(),
             nn.Linear(1280, 1280),
-        )
+        ).to(device)
         nn.init.zeros_(self.cc_projection[-1].weight)
         nn.init.zeros_(self.cc_projection[-1].bias)


     def generate_camera_batch(self, elevations, azimuths, use_abs=False):
         batch = get_batch_from_spherical(elevations, azimuths)
-
+
         abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]
-
+
         debug_cams = [[] for _ in range(len(azimuths))]
         for i, icam in enumerate(abs_cams):
             for j, jcam in enumerate(abs_cams):
@@ -49,15 +49,15 @@ class SPADPipeline(DiffusionPipeline):
                 dcam = icam - jcam
                 dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
                 debug_cams[i].append(dcam)
-
+
         batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])
-
+
         # Add intrinsics to the batch
         focal = 1 / np.tan(0.702769935131073 / 2)
         intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
         intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
         batch["render_intrinsics_flat"] = intrinsics[:, [0,1,0,1], [0,1,-1,-1]]
-
+
         return batch

     def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
@@ -68,15 +68,15 @@ class SPADPipeline(DiffusionPipeline):
         if gaussian_blob.max() > 0:
             gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
         gaussian_blob = 255.0 - gaussian_blob
-
+
         gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
         gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3,-1)
         gaussian_blob = torch.from_numpy(gaussian_blob)
-
+
         return gaussian_blob

     @torch.no_grad()
-    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
+    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
                  elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
         device = self.device
@@ -85,7 +85,7 @@ class SPADPipeline(DiffusionPipeline):
         if elevations is None or azimuths is None:
             elevations = [45] * 4
             azimuths = [0, 90, 180, 270]
-
+
         n_views = len(elevations)
         camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
         camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
@@ -104,7 +104,7 @@ class SPADPipeline(DiffusionPipeline):
         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

         # Encode camera data
-        camera_embeddings = self.cc_projection(camera_batch["cam"])
+        camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)

         # Prepare latents
         latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
@@ -127,45 +127,37 @@ class SPADPipeline(DiffusionPipeline):

         latent_height, latent_width = 64, 64  # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
         n_objects = 2;
-        latents = torch.randn(n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype)
+        latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)

         # Set up scheduler
         # self.scheduler.set_timesteps(num_inference_steps)
-        self.scheduler.set_timesteps(10)
+        self.scheduler.set_timesteps(50)
         # Repeat text_embeddings to match the desired dimensions
         text_embeddings = text_embeddings.repeat(n_objects, 1, 1)  # Shape: [2, max_seq_len, 512]

         # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
         text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
+        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
         # Denoising loop
         for t in tqdm(self.scheduler.timesteps):
             # Expand timesteps to match shape [batch_size, 1, 1]
             # timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
             timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
-
-            # # Repeat text_embeddings to match the desired dimensions
-            # text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 512]
-
-            # # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
-            # text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
-
-            # print("old cam shape: ", camera_embeddings.shape)
-            camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
-            # print("cam emb shape: ", camera_embeddings.shape)
+
             # Prepare context
             context = [
                 # text_embeddings.unsqueeze(1), # [batch_size, 1, max_seq_len, 768]
                 # camera_embeddings.unsqueeze(1) * 0.0, # [batch_size, 1, 1280] * 0.0
                 # epi_constraint_masks # Keep this as is for now
-                text_embeddings, # [n_objects, n_views, max_seq_len, 768]
-                camera_embeddings # [n_objects, n_views, 1280]
-                torch.ones(n_objects, n_views, 6, 32, 32)
+                text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
+                camera_embeddings, # [n_objects, n_views, 1280]
+                torch.ones(n_objects, n_views, 6, 32, 32).to(device)
             ]

             # Predict noise residual
             noise_pred = self.unet(
-                latents, # Shape: [batch_size, 1, 4, 64, 64]
-                timesteps=timesteps, # Shape: [batch_size, 1, 1]
+                latents.to(device), # Shape: [batch_size, 1, 4, 64, 64]
+                timesteps=timesteps.to(device), # Shape: [batch_size, 1, 1]
                 context=context
             )

@@ -179,19 +171,18 @@ class SPADPipeline(DiffusionPipeline):

         # reduce latents
         #EXPERIMENTAL
-        # If you need to reduce the channels from 10 to 4
-        latents = latents[:, :, :4, :, :]  # Select only the first 4 channels
-        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])
+        latents_reshaped = latents[:, 0, :, :, :]  # Selecting the first view
+
         # Decode latents
-        images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]

         # Post-process images
         images = (images / 2 + 0.5).clamp(0, 1)

         if images.dim() == 5:
-            images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # For 5D tensors
+            images_output = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # For 5D tensors
         elif images.dim() == 4:
-            images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # For 4D tensors
+            images_output = images.cpu().permute(0, 2, 3, 1).float().numpy()  # For 4D tensors
         else:
             raise ValueError(f"Unexpected image dimensions: {images.shape}")
 
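And a hypothetical usage sketch of the pipeline after this commit. The repo id and prompt are assumptions for illustration, not part of the commit; from_pretrained and to() come from the DiffusionPipeline base class.

# Hypothetical usage sketch -- repo id and prompt are assumptions, not part
# of this commit.
import torch
from pipeline_spad import SPADPipeline

pipe = SPADPipeline.from_pretrained("jadechoghari/spad")  # hypothetical repo id
pipe.to("cuda")

# Four views at 45-degree elevation, 90 degrees apart (the defaults in __call__).
outputs = pipe(
    "a toy robot",
    elevations=[45] * 4,
    azimuths=[0, 90, 180, 270],
)

Note that as committed, __call__ hardcodes self.scheduler.set_timesteps(50), so the num_inference_steps argument is effectively ignored.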