nyanko7 committed
Commit: 0547714
Parent: 5d81c4c

dynamic aspect support

Files changed (1):
  1. pipeline.py  +10 -4
pipeline.py CHANGED
@@ -96,11 +96,12 @@ class SEGCFGSelfAttnProcessor:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
-    def __init__(self, blur_sigma=1.0, do_cfg=True, inf_blur_threshold=9999.0):
+    def __init__(self, latent_ratio=1.0, blur_sigma=1.0, do_cfg=True, inf_blur_threshold=9999.0):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         self.blur_sigma = blur_sigma
         self.do_cfg = do_cfg
+        self.latent_ratio = latent_ratio
         if self.blur_sigma > inf_blur_threshold:
             self.inf_blur = True
         else:
@@ -157,13 +158,16 @@ class SEGCFGSelfAttnProcessor:
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        height = width = math.isqrt(query.shape[2])
+        # height = width = math.isqrt(query.shape[2])
+        height = round(self.latent_ratio * math.sqrt(query.shape[2] / self.latent_ratio))
+        width = round(math.sqrt(query.shape[2] / self.latent_ratio))
+        # A = query.shape[2]
+        # height, width = int(self.latent_ratio * torch.sqrt(A/self.latent_ratio)), int(torch.sqrt(A/self.latent_ratio))
         if self.do_cfg:
             query_uncond, query_org, query_ptb = query.chunk(3)
             query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size//3, attn.heads * head_dim, height, width)
@@ -1409,7 +1413,9 @@ class StableDiffusionXLPipeline(
         # change attention layer in UNet if use SEG
         if self.do_seg:
 
-            replace_processor = SEGCFGSelfAttnProcessor(blur_sigma=seg_blur_sigma, do_cfg=self.do_classifier_free_guidance)
+            # h/w
+            latent_ratio = height / width
+            replace_processor = SEGCFGSelfAttnProcessor(latent_ratio=latent_ratio, blur_sigma=seg_blur_sigma, do_cfg=self.do_classifier_free_guidance)
 
             if self.seg_applied_layers_index:
                 drop_layers = self.seg_applied_layers_index
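
For context on the new arithmetic: the self-attention sequence length is the flattened height × width of the latent feature map, so once the aspect ratio (height / width) is known, both dimensions can be recovered instead of assuming a square grid. Below is a minimal standalone sketch of that recovery; the helper name recover_latent_hw and the example sizes are illustrative, not part of pipeline.py.

import math

def recover_latent_hw(seq_len: int, latent_ratio: float) -> tuple[int, int]:
    # seq_len is height * width of the flattened latent tokens, and
    # latent_ratio is height / width, so width = sqrt(seq_len / latent_ratio)
    # and height = latent_ratio * width; round() absorbs small float error.
    width = round(math.sqrt(seq_len / latent_ratio))
    height = round(latent_ratio * math.sqrt(seq_len / latent_ratio))
    return height, width

# A 96x64 latent grid (3:2 aspect) flattens to 6144 tokens.
assert recover_latent_hw(6144, 96 / 64) == (96, 64)
# The square case still matches the old math.isqrt path: 64x64 -> 4096 tokens.
assert recover_latent_hw(4096, 1.0) == (64, 64)

The diff passes the ratio in explicitly (latent_ratio = height / width in the pipeline) rather than re-deriving it per attention layer, which keeps the reshape to (batch, channels, height, width) valid for non-square latents.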