NatanBagrov
committed on
Commit 10c31ce
1 Parent(s): d204b62
added option to skip mid block (#5)
- added option to skip mid block (b84b5c4088d29603d9c765d32549bb39b23231ed)
- pipeline.py +28 -12
pipeline.py CHANGED
@@ -52,6 +52,19 @@ def custom_sort_order(obj):
     return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
 
 
+class FlexibleIdentityBlock(nn.Module):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ):
+        return hidden_states
+
+
 class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
     configurations = FlexibleUnetConfigurations
 
@@ -105,18 +118,21 @@ class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
         mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
         mid_num_attentions = self.configurations.get("mid_num_attentions")
         mid_num_resnets = self.configurations.get("mid_num_resnets")
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if mid_num_resnets == mid_num_attentions == 0:
+            self.mid_block = FlexibleIdentityBlock()
+        else:
+            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
+                                                             temb_channels=temb_dim,
+                                                             resnet_act_fn=resnet_act_fn,
+                                                             resnet_eps=resnet_eps,
+                                                             cross_attention_dim=cross_attention_dim,
+                                                             num_attention_heads=num_attention_heads,
+                                                             num_resnets=mid_num_resnets,
+                                                             num_attentions=mid_num_attentions,
+                                                             mix_block_in_forward=mix_block_in_forward,
+                                                             add_upsample=mid_block_add_upsample
+                                                             )
 
         ###############
         # Up blocks #
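
Below is a minimal, self-contained sketch (not the repository's code) of what this change enables: it reimplements the FlexibleIdentityBlock from the diff together with a hypothetical build_mid_block helper that mirrors the selection logic above, so you can check that setting both mid_num_resnets and mid_num_attentions to 0 skips the mid block and passes hidden states through unchanged. The helper name and the bare configuration dict are assumptions for illustration; in the repository the logic lives inside FlexibleUNet2DConditionModel and reads from FlexibleUnetConfigurations.

from typing import Any, Dict, Optional

import torch
from torch import nn


class FlexibleIdentityBlock(nn.Module):
    # Same signature as the block added in the diff: accepts the keyword
    # arguments a UNet mid block would receive and returns the hidden
    # states unchanged.
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        return hidden_states


def build_mid_block(configurations: Dict[str, Any]) -> nn.Module:
    # Hypothetical helper mirroring the new branch in the diff: when both
    # counts are zero the mid block is skipped; otherwise the real
    # FlexibleUNetMidBlock2DCrossAttn would be constructed (omitted here,
    # since that class lives in the repository and is not importable in
    # this sketch).
    mid_num_attentions = configurations.get("mid_num_attentions")
    mid_num_resnets = configurations.get("mid_num_resnets")
    if mid_num_resnets == mid_num_attentions == 0:
        return FlexibleIdentityBlock()
    raise NotImplementedError("FlexibleUNetMidBlock2DCrossAttn would be built here")


if __name__ == "__main__":
    mid_block = build_mid_block({"mid_num_resnets": 0, "mid_num_attentions": 0})
    sample = torch.randn(1, 4, 8, 8)
    out = mid_block(sample, encoder_hidden_states=torch.randn(1, 77, 768))
    assert torch.equal(out, sample)  # hidden states pass through untouched
    print(type(mid_block).__name__, "skips the mid block entirely")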