dynamic aspect support
pipeline.py (changed, +10 -4)
@@ -96,11 +96,12 @@ class SEGCFGSelfAttnProcessor:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
-    def __init__(self, blur_sigma=1.0, do_cfg=True, inf_blur_threshold=9999.0):
+    def __init__(self, latent_ratio=1.0, blur_sigma=1.0, do_cfg=True, inf_blur_threshold=9999.0):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         self.blur_sigma = blur_sigma
         self.do_cfg = do_cfg
+        self.latent_ratio = latent_ratio
         if self.blur_sigma > inf_blur_threshold:
             self.inf_blur = True
         else:
@@ -157,13 +158,16 @@ class SEGCFGSelfAttnProcessor:
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        height = width = math.isqrt(query.shape[2])
+        # height = width = math.isqrt(query.shape[2])
+        height = round(self.latent_ratio * math.sqrt(query.shape[2] / self.latent_ratio))
+        width = round(math.sqrt(query.shape[2] / self.latent_ratio))
+        # A = query.shape[2]
+        # height, width = int(self.latent_ratio * torch.sqrt(A/self.latent_ratio)), int(torch.sqrt(A/self.latent_ratio))
         if self.do_cfg:
             query_uncond, query_org, query_ptb = query.chunk(3)
             query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size//3, attn.heads * head_dim, height, width)
@@ -1409,7 +1413,9 @@ class StableDiffusionXLPipeline(
         # change attention layer in UNet if use SEG
         if self.do_seg:
 
-            replace_processor = SEGCFGSelfAttnProcessor(blur_sigma=seg_blur_sigma, do_cfg=self.do_classifier_free_guidance)
+            # h/w
+            latent_ratio = height / width
+            replace_processor = SEGCFGSelfAttnProcessor(latent_ratio=latent_ratio, blur_sigma=seg_blur_sigma, do_cfg=self.do_classifier_free_guidance)
 
             if self.seg_applied_layers_index:
                 drop_layers = self.seg_applied_layers_index
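The third hunk feeds the processor a ratio computed from the requested image size, latent_ratio = height / width. That pixel-space ratio can be reused unchanged at every self-attention resolution because the VAE and each UNet downsampling stage divide both spatial dimensions by the same factor, so the token grid keeps the image's aspect ratio (assuming height and width are divisible by those factors). A small hedged check of that assumption; the scale factors below are illustrative, not taken from the commit:

import math

height, width = 768, 1024          # requested image size in pixels (example)
vae_scale_factor = 8               # pixels -> latent, typical for SDXL-style VAEs
latent_ratio = height / width      # what the pipeline passes to the processor

for down in (1, 2, 4):             # extra downsampling inside the UNet (assumed)
    h = height // (vae_scale_factor * down)
    w = width // (vae_scale_factor * down)
    seq_len = h * w                # flattened self-attention sequence length
    # The pixel-space ratio still factors the token count correctly.
    assert round(math.sqrt(seq_len / latent_ratio)) == w
    assert round(latent_ratio * math.sqrt(seq_len / latent_ratio)) == h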