jadechoghari committed on
Commit
0ea7028
1 Parent(s): 3db57e8

Update unet/mv_unet.py

Browse files
Files changed (1) hide show
  1. unet/mv_unet.py +10 -1
unet/mv_unet.py CHANGED
@@ -204,11 +204,14 @@ class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
204
  timesteps = rearrange(timesteps, "n v -> (n v)")
205
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
206
  time = self.time_embed(t_emb)
 
207
  time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
 
208
 
209
  # extract txt and cam embedding (absolute) from context
210
  if len(context) == 2:
211
  txt, cam = context
 
212
  elif len(context) == 3:
213
  txt, cam, epi_mask = context
214
  txt = (txt, epi_mask)
@@ -219,13 +222,19 @@ class SPADUnetModel(UNetModel, ModelMixin, ConfigMixin):
219
  if x.shape[2] > 4:
220
  plucker, x = x[:, :, 4:], x[:, :, :4]
221
  txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
 
222
 
 
223
  # combine timestep and camera embedding (resnet)
224
- time_cam = time + cam
225
  del time, cam
226
 
227
  # encode
 
228
  h = x.type(self.dtype)
 
 
 
229
  hs = self.encode(h, time_cam, txt, self.input_blocks)
230
 
231
  # middle block
 
204
  timesteps = rearrange(timesteps, "n v -> (n v)")
205
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
206
  time = self.time_embed(t_emb)
207
+ print("old time: ", time.shape)
208
  time = rearrange(time, "(n v) d -> n v d", n=n_objects, v=n_views)
209
+ # 2, 4, 1280
210
 
211
  # extract txt and cam embedding (absolute) from context
212
  if len(context) == 2:
213
  txt, cam = context
214
+ print("txt shape", txt.shape)
215
  elif len(context) == 3:
216
  txt, cam, epi_mask = context
217
  txt = (txt, epi_mask)
 
222
  if x.shape[2] > 4:
223
  plucker, x = x[:, :, 4:], x[:, :, :4]
224
  txt = (*txt, plucker) if isinstance(txt, tuple) else (txt, plucker)
225
+ print("extracted")
226
 
227
+ # print("txt shape: ", txt.shape)
228
  # combine timestep and camera embedding (resnet)
229
+ time_cam = time # add + cam later
230
  del time, cam
231
 
232
  # encode
233
+
234
  h = x.type(self.dtype)
235
+ print("h: ", h.shape)
236
+ print("time_cam: ", time_cam.shape)
237
+ # print("txt: ", txt.shape)
238
  hs = self.encode(h, time_cam, txt, self.input_blocks)
239
 
240
  # middle block