Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
•
6c1d070
1
Parent(s):
023c7dd
support higher res/lower res sampling than training time
Browse files
score_sde/models/ncsnpp_generator_adagn.py
CHANGED
@@ -379,7 +379,8 @@ class NCSNpp(nn.Module):
|
|
379 |
#print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
|
380 |
h = modules[m_idx](hs[-1], temb, zemb)
|
381 |
m_idx += 1
|
382 |
-
if
|
|
|
383 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
384 |
h = modules[m_idx](h, cond, cond_mask)
|
385 |
else:
|
@@ -415,6 +416,7 @@ class NCSNpp(nn.Module):
|
|
415 |
h = hs[-1]
|
416 |
h = modules[m_idx](h, temb, zemb)
|
417 |
m_idx += 1
|
|
|
418 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
419 |
h = modules[m_idx](h, cond, cond_mask)
|
420 |
else:
|
@@ -431,7 +433,8 @@ class NCSNpp(nn.Module):
|
|
431 |
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
|
432 |
m_idx += 1
|
433 |
|
434 |
-
if h.shape[-1] in self.attn_resolutions:
|
|
|
435 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
436 |
h = modules[m_idx](h, cond, cond_mask)
|
437 |
else:
|
|
|
379 |
#print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
|
380 |
h = modules[m_idx](hs[-1], temb, zemb)
|
381 |
m_idx += 1
|
382 |
+
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
|
383 |
+
#if h.shape[-1] in self.attn_resolutions:
|
384 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
385 |
h = modules[m_idx](h, cond, cond_mask)
|
386 |
else:
|
|
|
416 |
h = hs[-1]
|
417 |
h = modules[m_idx](h, temb, zemb)
|
418 |
m_idx += 1
|
419 |
+
|
420 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
421 |
h = modules[m_idx](h, cond, cond_mask)
|
422 |
else:
|
|
|
433 |
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
|
434 |
m_idx += 1
|
435 |
|
436 |
+
#if h.shape[-1] in self.attn_resolutions:
|
437 |
+
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
|
438 |
if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
|
439 |
h = modules[m_idx](h, cond, cond_mask)
|
440 |
else:
|