jadechoghari committed
Commit
8056866
1 Parent(s): c6416ec

Update pipeline_spad.py

Files changed (1): pipeline_spad.py +20 -23
pipeline_spad.py CHANGED
@@ -81,7 +81,7 @@ class SPADPipeline(DiffusionPipeline):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
         device = self.device

-        # Generate camera batch
+        # generate camera batch
         if elevations is None or azimuths is None:
             elevations = [45] * 4
             azimuths = [0, 90, 180, 270]
@@ -90,23 +90,23 @@ class SPADPipeline(DiffusionPipeline):
         camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
         camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}

-        # Prepare gaussian blob initialization
+        # prepare gaussian blob initialization
         blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
         camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)

-        # Encode text
+        # encode text
         text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
         text_embeddings = self.text_encoder(text_input_ids)[0]

-        # Prepare unconditional embeddings for classifier-free guidance
+        # prepare unconditional embeddings for classifier-free guidance
         max_length = text_input_ids.shape[-1]
         uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

-        # Encode camera data
+        # encode camera data
         camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)

-        # Prepare latents
+        # prepare latents
         latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
         latents = self.prepare_latents(
             batch_size,
@@ -119,36 +119,33 @@ class SPADPipeline(DiffusionPipeline):
             generator=None,
         )

-        # Prepare epi_constraint_masks (placeholder, replace with actual implementation)
+        # prepare epi_constraint_masks (placeholder - replace with actual implementation later - MIGHT AFFECT PERFORMANCE)
         epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)

-        # Prepare plucker embeddings (placeholder, replace with actual implementation)
+        # prepare plucker embeddings (placeholder, replace with actual implementation - MIGHT AFFECT PERFORMANCE)
         plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)

         latent_height, latent_width = 64, 64 # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
         n_objects = 2;
         latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)

-        # Set up scheduler
+        # set up scheduler
         # self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_timesteps(50)
-        # Repeat text_embeddings to match the desired dimensions
-        text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 512]
+        # repeat text_embeddings to match the desired dimensions
+        text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 768]

-        # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
+        # reshape text_embeddings to match [n_objects, n_views, max_seq_len, 768]
         text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
         camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
-        # Denoising loop
+        # denoising loop
         for t in tqdm(self.scheduler.timesteps):
-            # Expand timesteps to match shape [batch_size, 1, 1]
+            # expand timesteps to match shape [batch_size, 1, 1]
             # timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
             timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)

-            # Prepare context
+            # prepare context
             context = [
-                # text_embeddings.unsqueeze(1), # [batch_size, 1, max_seq_len, 768]
-                # camera_embeddings.unsqueeze(1) * 0.0, # [batch_size, 1, 1280] * 0.0
-                # epi_constraint_masks # Keep this as is for now
                 text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
                 camera_embeddings, # [n_objects, n_views, 1280]
                 torch.ones(n_objects, n_views, 6, 32, 32).to(device)
@@ -161,22 +158,22 @@ class SPADPipeline(DiffusionPipeline):
                 context=context
             )

-            # Perform guidance
+            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


-            # Compute previous noisy sample
+            # compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # reduce latents
-        #EXPERIMENTAL
+        #EXPERIMENTAL - MIGHT AFFECT PERFORMANCE
        latents_reshaped = latents[:, 0, :, :, :] # Selecting the first view

-        # Decode latents
+        # decode latents
        images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]

-        # Post-process images
+        # post-process images
        images = (images / 2 + 0.5).clamp(0, 1)

        if images.dim() == 5:
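
A note on the guidance step: `noise_pred.chunk(2)` assumes the UNet ran on a doubled batch (unconditional first, text-conditioned second). As committed, the context list carries only `text_embeddings`, so the `uncond_embeddings` computed earlier appear unused and the split falls across the `n_objects` axis instead. A minimal, self-contained sketch of the intended classifier-free guidance combination (shapes are illustrative, not the pipeline's actual ones):

```python
import torch

guidance_scale = 7.5

# stand-ins for a UNet prediction over a concatenated [uncond, text] batch
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)
noise_pred = torch.cat([noise_pred_uncond, noise_pred_text])  # [2, 4, 64, 64]

# split the doubled batch back apart and blend: uncond + scale * (text - uncond)
uncond, text = noise_pred.chunk(2)
guided = uncond + guidance_scale * (text - uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```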
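
`get_gaussian_image` is not part of this diff; below is a plausible stand-in, assuming it returns a centered Gaussian blob as a 3-channel image in [-1, 1]. The `size` argument and the 256x256 default are guesses, not the repo's actual implementation:

```python
import torch

def get_gaussian_image(sigma: float = 0.5, size: int = 256) -> torch.Tensor:
    """Hypothetical stand-in: centered 2D Gaussian blob as a [3, size, size] image in [-1, 1]."""
    coords = torch.linspace(-1.0, 1.0, size)
    y, x = torch.meshgrid(coords, coords, indexing="ij")
    blob = torch.exp(-(x ** 2 + y ** 2) / (2.0 * sigma ** 2))  # peak of 1 at the center
    blob = blob * 2.0 - 1.0  # map [0, 1] -> [-1, 1]
    return blob.unsqueeze(0).repeat(3, 1, 1)  # [3, size, size]
```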
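
For completeness, a hedged usage sketch based only on the parameters visible in this hunk (`prompt`, `elevations`, `azimuths`, `blob_sigma`, `guidance_scale`); the checkpoint id and the exact `__call__` signature are assumptions for illustration:

```python
import torch
from pipeline_spad import SPADPipeline

# hypothetical checkpoint id, for illustration only
pipe = SPADPipeline.from_pretrained("jadechoghari/spad", torch_dtype=torch.float16).to("cuda")

images = pipe(
    prompt="a wooden toy robot",
    elevations=[45] * 4,          # the defaults applied above when None
    azimuths=[0, 90, 180, 270],
    blob_sigma=0.5,               # assumed default
    guidance_scale=7.5,
)
```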