jadechoghari committed
Commit ef48aca · Parent: a1c4d22

Update pipeline_spad.py

pipeline_spad.py CHANGED (+28 -37)
@@ -12,7 +12,7 @@ from .geometry import get_batch_from_spherical
 class SPADPipeline(DiffusionPipeline):
     def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
         super().__init__()
-
+
         self.register_modules(
             unet=unet,
             vae=vae,
@@ -20,26 +20,26 @@ class SPADPipeline(DiffusionPipeline):
             tokenizer=tokenizer,
             scheduler=scheduler
         )
-
+
         self.cfg_conds = ["txt", "cam", "epi", "plucker"]
         self.cfg_scales = [7.5, 1.0, 1.0, 1.0] # Default scales, adjust as needed
         self.use_abs_extrinsics = False
         self.use_intrinsic = False
-
+
         self.cc_projection = nn.Sequential(
             nn.Linear(4 if not self.use_intrinsic else 8, 1280),
             nn.SiLU(),
             nn.Linear(1280, 1280),
-        )
+        ).to(device)
         nn.init.zeros_(self.cc_projection[-1].weight)
         nn.init.zeros_(self.cc_projection[-1].bias)
 
 
     def generate_camera_batch(self, elevations, azimuths, use_abs=False):
         batch = get_batch_from_spherical(elevations, azimuths)
-
+
         abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]
-
+
         debug_cams = [[] for _ in range(len(azimuths))]
         for i, icam in enumerate(abs_cams):
             for j, jcam in enumerate(abs_cams):
@@ -49,15 +49,15 @@ class SPADPipeline(DiffusionPipeline):
                 dcam = icam - jcam
                 dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
                 debug_cams[i].append(dcam)
-
+
         batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])
-
+
         # Add intrinsics to the batch
         focal = 1 / np.tan(0.702769935131073 / 2)
         intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
         intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
         batch["render_intrinsics_flat"] = intrinsics[:, [0,1,0,1], [0,1,-1,-1]]
-
+
         return batch
 
     def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
@@ -68,15 +68,15 @@ class SPADPipeline(DiffusionPipeline):
         if gaussian_blob.max() > 0:
             gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
         gaussian_blob = 255.0 - gaussian_blob
-
+
         gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
         gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3,-1)
         gaussian_blob = torch.from_numpy(gaussian_blob)
-
+
         return gaussian_blob
 
     @torch.no_grad()
-    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
+    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
                  elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
         device = self.device
@@ -85,7 +85,7 @@ class SPADPipeline(DiffusionPipeline):
         if elevations is None or azimuths is None:
             elevations = [45] * 4
             azimuths = [0, 90, 180, 270]
-
+
         n_views = len(elevations)
         camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
         camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
@@ -104,7 +104,7 @@ class SPADPipeline(DiffusionPipeline):
         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
         # Encode camera data
-        camera_embeddings = self.cc_projection(camera_batch["cam"])
+        camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)
 
         # Prepare latents
         latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
@@ -127,45 +127,37 @@ class SPADPipeline(DiffusionPipeline):
 
         latent_height, latent_width = 64, 64 # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
         n_objects = 2;
-        latents = torch.randn(n_objects, n_views,
+        latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)
 
         # Set up scheduler
         # self.scheduler.set_timesteps(num_inference_steps)
-        self.scheduler.set_timesteps(
+        self.scheduler.set_timesteps(50)
         # Repeat text_embeddings to match the desired dimensions
         text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 512]
 
         # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
         text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
+        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
         # Denoising loop
         for t in tqdm(self.scheduler.timesteps):
             # Expand timesteps to match shape [batch_size, 1, 1]
             # timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
             timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
-
-            # # Repeat text_embeddings to match the desired dimensions
-            # text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 512]
-
-            # # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
-            # text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
-
-            # print("old cam shape: ", camera_embeddings.shape)
-            camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
-            # print("cam emb shape: ", camera_embeddings.shape)
+
             # Prepare context
             context = [
                 # text_embeddings.unsqueeze(1), # [batch_size, 1, max_seq_len, 768]
                 # camera_embeddings.unsqueeze(1) * 0.0, # [batch_size, 1, 1280] * 0.0
                 # epi_constraint_masks # Keep this as is for now
-                text_embeddings, # [n_objects, n_views, max_seq_len, 768]
-                camera_embeddings # [n_objects, n_views, 1280]
-                torch.ones(n_objects, n_views, 6, 32, 32)
+                text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
+                camera_embeddings, # [n_objects, n_views, 1280]
+                torch.ones(n_objects, n_views, 6, 32, 32).to(device)
             ]
 
             # Predict noise residual
             noise_pred = self.unet(
-                latents, # Shape: [batch_size, 1, 4, 64, 64]
-                timesteps=timesteps, # Shape: [batch_size, 1, 1]
+                latents.to(device), # Shape: [batch_size, 1, 4, 64, 64]
+                timesteps=timesteps.to(device), # Shape: [batch_size, 1, 1]
                 context=context
             )
 
@@ -179,19 +171,18 @@ class SPADPipeline(DiffusionPipeline):
 
         # reduce latents
         #EXPERIMENTAL
-
-
-        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])
+        latents_reshaped = latents[:, 0, :, :, :] # Selecting the first view
+
         # Decode latents
-        images = self.vae.decode(
+        images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]
 
         # Post-process images
         images = (images / 2 + 0.5).clamp(0, 1)
 
         if images.dim() == 5:
-
+            images_output = images.cpu().permute(0, 1, 3, 4, 2).float().numpy() # For 5D tensors
         elif images.dim() == 4:
-
+            images_output = images.cpu().permute(0, 2, 3, 1).float().numpy() # For 4D tensors
         else:
             raise ValueError(f"Unexpected image dimensions: {images.shape}")
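For readers skimming the diff, the pairwise camera encoding that feeds `cc_projection` can be exercised on its own. The sketch below is a standalone reconstruction of what `generate_camera_batch` computes for the default views used in `__call__`; it is illustrative only and not part of the commit.

# Standalone sketch (not part of the commit) of the relative-camera encoding
# built in generate_camera_batch: each ordered pair of views (i, j) becomes
# [d_elevation, sin(d_azimuth), cos(d_azimuth), d_radius] -- the 4-dim input
# expected by cc_projection's nn.Linear(4, 1280) when use_intrinsic is False.
import math
import torch

elevations = [45.0] * 4              # defaults from __call__
azimuths = [0.0, 90.0, 180.0, 270.0]
radius = 3.5                         # fixed camera distance used in the commit

abs_cams = [torch.tensor([e, a, radius]) for e, a in zip(elevations, azimuths)]

rows = []
for icam in abs_cams:
    row = []
    for jcam in abs_cams:
        d = icam - jcam
        # Note: like the committed code, this passes the raw azimuth
        # difference to sin/cos, so the unit of `azimuths` (degrees here)
        # is used as-is.
        row.append(torch.tensor([d[0].item(),
                                 math.sin(d[1].item()),
                                 math.cos(d[1].item()),
                                 d[2].item()]))
    rows.append(torch.stack(row))

cam = torch.stack(rows)
print(cam.shape)                     # torch.Size([4, 4, 4]) -> [n_views, n_views, 4]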
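A minimal usage sketch follows. The repo id and the `custom_pipeline` loading path are assumptions for illustration; only the `__call__` arguments mirror the signature visible in the diff, and the return format depends on the tail of the file, which this commit does not touch. Note also that the new `).to(device)` in `__init__` references a `device` name that is not defined in any shown hunk, so running this depends on the rest of the file providing it.

# Hypothetical usage sketch -- the repo id below is a placeholder, and
# loading via custom_pipeline assumes the repo ships pipeline_spad.py
# as a diffusers custom pipeline.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "user/spad-model",                  # placeholder repo id
    custom_pipeline="user/spad-model",  # assumed custom-pipeline source
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Arguments mirror the __call__ signature shown in the diff; elevations and
# azimuths default to four views at 45 degrees elevation, 90 degrees apart.
out = pipe(
    prompt="a ceramic teapot",
    num_inference_steps=50,
    guidance_scale=7.5,
    elevations=[45] * 4,
    azimuths=[0, 90, 180, 270],
)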