Spaces:
Running
on
Zero
Running
on
Zero
Update model.py
Browse files
model.py
CHANGED
@@ -1587,7 +1587,7 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1587 |
|
1588 |
bsz, num_drags, drag_dim = drags.shape
|
1589 |
assert num_drags == self.num_drags
|
1590 |
-
if (self.
|
1591 |
if force_drop_ids is None:
|
1592 |
drop_ids = torch.rand(bsz, device=x_cond_extra.device) < self.drag_dropout_prob
|
1593 |
else:
|
|
|
1587 |
|
1588 |
bsz, num_drags, drag_dim = drags.shape
|
1589 |
assert num_drags == self.num_drags
|
1590 |
+
if (self.training and self.drag_dropout_prob > 0) or force_drop_ids is not None:
|
1591 |
if force_drop_ids is None:
|
1592 |
drop_ids = torch.rand(bsz, device=x_cond_extra.device) < self.drag_dropout_prob
|
1593 |
else:
|