jadechoghari
commited on
Commit
•
0ea7028
1
Parent(s):
3db57e8
Update unet/mv_unet.py
Browse files- unet/mv_unet.py +10 -1
unet/mv_unet.py
CHANGED
@@ -204,11 +204,14 @@ class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
|
|
204 |
timesteps = rearrange(timesteps, "n v -> (n v)")
|
205 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
206 |
time = self.time_embed(t_emb)
|
|
|
207 |
time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
|
|
|
208 |
|
209 |
# extract txt and cam embedding (absolute) from context
|
210 |
if len(context) == 2:
|
211 |
txt, cam = context
|
|
|
212 |
elif len(context) == 3:
|
213 |
txt, cam, epi_mask = context
|
214 |
txt = (txt, epi_mask)
|
@@ -219,13 +222,19 @@ class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
|
|
219 |
if x.shape[2] > 4:
|
220 |
plucker, x = x[:, :, 4:], x[:, :, :4]
|
221 |
txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
|
|
|
222 |
|
|
|
223 |
# combine timestep and camera embedding (resnet)
|
224 |
-
time_cam = time + cam
|
225 |
del time, cam
|
226 |
|
227 |
# encode
|
|
|
228 |
h = x.type(self.dtype)
|
|
|
|
|
|
|
229 |
hs = self.encode(h, time_cam, txt, self.input_blocks)
|
230 |
|
231 |
# middle block
|
|
|
204 |
timesteps = rearrange(timesteps, "n v -> (n v)")
|
205 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
206 |
time = self.time_embed(t_emb)
|
207 |
+
print("old time: ", time.shape)
|
208 |
time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
|
209 |
+
# 2, 4, 1280
|
210 |
|
211 |
# extract txt and cam embedding (absolute) from context
|
212 |
if len(context) == 2:
|
213 |
txt, cam = context
|
214 |
+
print("txt shape", txt.shape)
|
215 |
elif len(context) == 3:
|
216 |
txt, cam, epi_mask = context
|
217 |
txt = (txt, epi_mask)
|
|
|
222 |
if x.shape[2] > 4:
|
223 |
plucker, x = x[:, :, 4:], x[:, :, :4]
|
224 |
txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
|
225 |
+
print("extracted")
|
226 |
|
227 |
+
# print("txt shape: ", txt.shape)
|
228 |
# combine timestep and camera embedding (resnet)
|
229 |
+
time_cam = time # add + cam later
|
230 |
del time, cam
|
231 |
|
232 |
# encode
|
233 |
+
|
234 |
h = x.type(self.dtype)
|
235 |
+
print("h: ", h.shape)
|
236 |
+
print("time_cam: ", time_cam.shape)
|
237 |
+
# print("txt: ", txt.shape)
|
238 |
hs = self.encode(h, time_cam, txt, self.input_blocks)
|
239 |
|
240 |
# middle block
|