jadechoghari
commited on
Commit
•
4cdc52f
1
Parent(s):
c3fcadd
Update unet/mv_attention.py
Browse files- unet/mv_attention.py +2 -2
unet/mv_attention.py
CHANGED
@@ -18,7 +18,7 @@ def conv_nd(dims, *args, **kwargs):
|
|
18 |
raise ValueError(f"unsupported dimensions: {dims}")
|
19 |
|
20 |
|
21 |
-
from
|
22 |
|
23 |
try:
|
24 |
import xformers
|
@@ -364,4 +364,4 @@ if __name__ == "__main__":
|
|
364 |
).cuda(),
|
365 |
torch.randn(n_objects, n_views, 6, 32, 32).cuda(),
|
366 |
]
|
367 |
-
x_post = spt_post(x, context=context)
|
|
|
18 |
raise ValueError(f"unsupported dimensions: {dims}")
|
19 |
|
20 |
|
21 |
+
from attention import *
|
22 |
|
23 |
try:
|
24 |
import xformers
|
|
|
364 |
).cuda(),
|
365 |
torch.randn(n_objects, n_views, 6, 32, 32).cuda(),
|
366 |
]
|
367 |
+
x_post = spt_post(x, context=context)
|