Spaces:
Runtime error
Runtime error
handle tuple / list for resolutions
Browse files
audiodiffusion/__init__.py
CHANGED
@@ -213,11 +213,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
|
213 |
step_generator = step_generator or generator
|
214 |
# For backwards compatibility
|
215 |
if type(self.unet.sample_size) == int:
|
216 |
-
self.unet.sample_size =
|
217 |
-
self.unet.sample_size
|
218 |
if noise is None:
|
219 |
noise = torch.randn(
|
220 |
-
|
|
|
221 |
generator=generator)
|
222 |
images = noise
|
223 |
mask = None
|
|
|
213 |
step_generator = step_generator or generator
|
214 |
# For backwards compatibility
|
215 |
if type(self.unet.sample_size) == int:
|
216 |
+
self.unet.sample_size = (self.unet.sample_size,
|
217 |
+
self.unet.sample_size)
|
218 |
if noise is None:
|
219 |
noise = torch.randn(
|
220 |
+
(batch_size, self.unet.in_channels, self.unet.sample_size[0],
|
221 |
+
self.unet.sample_size[1]),
|
222 |
generator=generator)
|
223 |
images = noise
|
224 |
mask = None
|