jadechoghari committed
Commit 8056866
1 Parent(s): c6416ec
Update pipeline_spad.py

pipeline_spad.py CHANGED (+20 -23)
@@ -81,7 +81,7 @@ class SPADPipeline(DiffusionPipeline):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
         device = self.device
 
-        #
+        # generate camera batch
         if elevations is None or azimuths is None:
             elevations = [45] * 4
             azimuths = [0, 90, 180, 270]
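For reference, the default view setup in this hunk is a fixed four-camera orbit: elevation 45° with azimuths 0°/90°/180°/270°. Below is a minimal sketch of the positions those angles describe on a sphere; it is not the repository's generate_camera_batch, and the axis convention and radius are assumptions made only for illustration.

import math

def camera_position(elevation_deg: float, azimuth_deg: float, radius: float = 1.0):
    # Spherical-to-Cartesian conversion: elevation measured from the horizontal plane,
    # azimuth rotating around the vertical (z) axis.
    el, az = math.radians(elevation_deg), math.radians(azimuth_deg)
    x = radius * math.cos(el) * math.cos(az)
    y = radius * math.cos(el) * math.sin(az)
    z = radius * math.sin(el)
    return x, y, z

# The four default views: same elevation, evenly spaced azimuths.
for az in (0, 90, 180, 270):
    print(camera_position(45, az))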
@@ -90,23 +90,23 @@ class SPADPipeline(DiffusionPipeline):
         camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
         camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
 
-        #
+        # prepare gaussian blob initialization
         blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
         camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
 
-        #
+        # encode text
         text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
         text_embeddings = self.text_encoder(text_input_ids)[0]
 
-        #
+        # prepare unconditional embeddings for classifier-free guidance
         max_length = text_input_ids.shape[-1]
         uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
-        #
+        # encode camera data
         camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)
 
-        #
+        # prepare latents
         latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
         latents = self.prepare_latents(
             batch_size,
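The text/unconditional encoding in this hunk follows the usual classifier-free-guidance recipe: encode the prompt and an empty prompt at the same padded length so the two batches can later be stacked. A standalone sketch with a stock CLIP text encoder follows; the checkpoint name is an assumption for illustration, since the pipeline uses whatever self.tokenizer and self.text_encoder it was loaded with.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = ["a photo of a chair"]
text_ids = tokenizer(prompt, padding="max_length",
                     max_length=tokenizer.model_max_length,
                     return_tensors="pt").input_ids
# Empty prompts padded to the same length give the unconditional branch.
uncond_ids = tokenizer([""] * len(prompt), padding="max_length",
                       max_length=text_ids.shape[-1],
                       return_tensors="pt").input_ids

with torch.no_grad():
    text_embeddings = text_encoder(text_ids)[0]      # [batch, 77, 768]
    uncond_embeddings = text_encoder(uncond_ids)[0]  # [batch, 77, 768]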
@@ -119,36 +119,33 @@ class SPADPipeline(DiffusionPipeline):
             generator=None,
         )
 
-        #
+        # prepare epi_constraint_masks (placeholder- replace with actual implementation later - MIGHT AFFECT PERFORMANCE)
         epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
 
-        #
+        # prepare plucker embeddings (placeholder, replace with actual implementation - MIGHT AFFECT PERFORMANCE)
         plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)
 
         latent_height, latent_width = 64, 64 # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
         n_objects = 2;
         latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)
 
-        #
+        # set up scheduler
         # self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_timesteps(50)
-        #
-        text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len,
+        # repeat text_embeddings to match the desired dimensions
+        text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 768]
 
-        #
+        # reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
         text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
         camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
-        #
+        # denoising loop
         for t in tqdm(self.scheduler.timesteps):
-            #
+            # expand timesteps to match shape [batch_size, 1, 1]
             # timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
             timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
 
-            #
+            # prepare context
             context = [
-                # text_embeddings.unsqueeze(1), # [batch_size, 1, max_seq_len, 768]
-                # camera_embeddings.unsqueeze(1) * 0.0, # [batch_size, 1, 1280] * 0.0
-                # epi_constraint_masks # Keep this as is for now
                 text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
                 camera_embeddings, # [n_objects, n_views, 1280]
                 torch.ones(n_objects, n_views, 6, 32, 32).to(device)
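The plucker_embeds tensor in this hunk is only a zero placeholder. As a reminder of what a real 6-channel Plücker embedding would encode, here is an illustrative helper, not the repository's implementation, that packs each ray's unit direction d and moment o × d; the ray origins, directions, and 32×32 resolution are assumptions for the sketch.

import torch

def plucker_rays(origins: torch.Tensor, directions: torch.Tensor) -> torch.Tensor:
    """origins, directions: [..., 3] -> [..., 6] Plücker coordinates (direction, moment)."""
    d = torch.nn.functional.normalize(directions, dim=-1)
    m = torch.cross(origins, d, dim=-1)  # moment vector o x d
    return torch.cat([d, m], dim=-1)

# One ray per latent pixel for a single 32x32 view, all cast from a camera at (0, 0, 2).
origins = torch.tensor([0.0, 0.0, 2.0]).expand(32 * 32, 3)
directions = torch.randn(32 * 32, 3)
print(plucker_rays(origins, directions).shape)  # torch.Size([1024, 6])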
@@ -161,22 +158,22 @@ class SPADPipeline(DiffusionPipeline):
                 context=context
             )
 
-            #
+            # perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
-            #
+            # compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
 
         # reduce latents
-        #EXPERIMENTAL
+        #EXPERIMENTAL - MIGHT AFFECT PERFORMANCE
         latents_reshaped = latents[:, 0, :, :, :] # Selecting the first view
 
-        #
+        # decode latents
         images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]
 
-        #
+        # post-process images
         images = (images / 2 + 0.5).clamp(0, 1)
 
         if images.dim() == 5:
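The guidance and scheduler-step pattern in this hunk is the standard diffusers denoising loop. Below is a self-contained sketch of that pattern with a stand-in scheduler and random tensors in place of the SPAD UNet's noise predictions; the DDIM scheduler choice and the single-view latent shapes are assumptions, not the pipeline's actual configuration.

import torch
from diffusers import DDIMScheduler

# Stand-in scheduler; the pipeline uses whichever scheduler it was configured with.
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

guidance_scale = 7.5
latents = torch.randn(1, 4, 64, 64)

for t in scheduler.timesteps:
    # A real pipeline runs the UNet on the unconditional and text-conditioned inputs;
    # random tensors stand in for those two stacked noise predictions here.
    noise_pred = torch.randn(2, 4, 64, 64)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# After decoding with the VAE, (images / 2 + 0.5).clamp(0, 1) maps the output
# from the VAE's [-1, 1] range into the displayable [0, 1] range.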