skytnt committed on
Commit 7a5efb4
1 Parent(s): ef85cd3

Update pipeline.py

Files changed (1)
  1. pipeline.py +149 -138
pipeline.py CHANGED
@@ -1,10 +1,9 @@
 import inspect
 import re
-from typing import Callable, List, Optional, Union
 import PIL
 import numpy as np
 import torch
-
+from typing import Callable, List, Optional, Union
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from diffusers.configuration_utils import FrozenDict
@@ -17,7 +16,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-re_attention = re.compile(r"""
+re_attention = re.compile(
+    r"""
 \\\(|
 \\\)|
 \\\[|
@@ -31,7 +31,9 @@ re_attention = re.compile(r"""
 ]|
 [^\\()\[\]:]+|
 :
-""", re.X)
+""",
+    re.X,
+)
 
 
 def parse_prompt_attention(text):
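This hunk only reformats `re_attention`, but for review context: `parse_prompt_attention` uses it to split a prompt into text runs, brackets, and explicit `(word:1.5)` weights, with `(...)` multiplying attention by 1.1 and `[...]` dividing by 1.1, compounding under nesting. A minimal, self-contained sketch of those bracket semantics (a toy character-level parser, not the function above):

```python
# A minimal sketch of the bracket semantics parse_prompt_attention
# implements on top of re_attention: "(...)" multiplies attention by 1.1,
# "[...]" divides by 1.1, and nesting compounds. This toy version works
# per character and skips escapes and explicit "(word:1.5)" weights,
# which the real function handles via the regex above.
def toy_parse(text, up=1.1, down=1 / 1.1):
    res, stack = [], []
    for ch in text:
        if ch in "([":
            stack.append((ch, len(res)))  # remember where the span starts
        elif ch in ")]" and stack:
            opener, start = stack.pop()
            mult = up if opener == "(" else down
            for item in res[start:]:  # reweight everything inside the span
                item[1] *= mult
        else:
            res.append([ch, 1.0])
    return res

print({c: round(w, 2) for c, w in toy_parse("a ((red)) [sky]") if c.strip()})
# {'a': 1.0, 'r': 1.21, 'e': 1.21, 'd': 1.21, 's': 0.91, 'k': 0.91, 'y': 0.91}
```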
@@ -84,17 +86,17 @@ def parse_prompt_attention(text):
         text = m.group(0)
         weight = m.group(1)
 
-        if text.startswith('\\'):
+        if text.startswith("\\"):
             res.append([text[1:], 1.0])
-        elif text == '(':
+        elif text == "(":
             round_brackets.append(len(res))
-        elif text == '[':
+        elif text == "[":
             square_brackets.append(len(res))
         elif weight is not None and len(round_brackets) > 0:
             multiply_range(round_brackets.pop(), float(weight))
-        elif text == ')' and len(round_brackets) > 0:
+        elif text == ")" and len(round_brackets) > 0:
             multiply_range(round_brackets.pop(), round_bracket_multiplier)
-        elif text == ']' and len(square_brackets) > 0:
+        elif text == "]" and len(square_brackets) > 0:
             multiply_range(square_brackets.pop(), square_bracket_multiplier)
         else:
             res.append([text, 1.0])
@@ -120,11 +122,7 @@ def parse_prompt_attention(text):
     return res
 
 
-def get_prompts_with_weights(
-        pipe: DiffusionPipeline,
-        prompt: List[str],
-        max_length: int
-):
+def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.
 
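Only the signature is collapsed to one line here; the body still maps every sub-token produced by the tokenizer to the weight of the text chunk it came from. A toy sketch of that pairing, with a hypothetical whitespace "tokenizer" standing in for `pipe.tokenizer`:

```python
# Toy sketch of the pairing get_prompts_with_weights performs: every
# sub-token inherits the weight of the text chunk it came from. The
# whitespace tokenizer and vocab here are hypothetical; the real code
# calls pipe.tokenizer(word).input_ids and truncates to max_length.
def toy_prompt_to_tokens(chunks, vocab):
    tokens, weights = [], []
    for word, weight in chunks:  # chunks = parse_prompt_attention output
        ids = [vocab.setdefault(t, len(vocab)) for t in word.split()]
        tokens += ids
        weights += [weight] * len(ids)
    return tokens, weights

vocab = {}
print(toy_prompt_to_tokens([("a red", 1.0), ("ball", 1.21)], vocab))
# ([0, 1, 2], [1.0, 1.0, 1.21])
```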
@@ -158,9 +156,7 @@ def get_prompts_with_weights(
     return tokens, weights
 
 
-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
-                           no_boseos_middle=True,
-                           chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
@@ -169,27 +165,24 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
     for i in range(len(tokens)):
         tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
         if no_boseos_middle:
-            weights[i] = [1.] + weights[i] + [1.] * (max_length - 1 - len(weights[i]))
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
             w = []
             if len(weights[i]) == 0:
-                w = [1.] * weights_length
+                w = [1.0] * weights_length
             else:
                 for j in range((len(weights[i]) - 1) // chunk_length + 1):
-                    w.append(1.)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length: min(len(weights[i]), (j + 1) * chunk_length)]
-                    w.append(1.)  # weight for ending token in this chunk
-                w += [1.] * (weights_length - len(w))
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
 
     return tokens, weights
 
 
 def get_unweighted_text_embeddings(
-        pipe: DiffusionPipeline,
-        text_input: torch.Tensor,
-        chunk_length: int,
-        no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
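The `1.` to `1.0` changes above are cosmetic, but the surrounding logic is worth spelling out: when `no_boseos_middle` is off, every chunk of payload tokens gets its own BOS/EOS weight slots. A simplified, runnable sketch with toy numbers (`weights_length` corresponds to `n_chunks * (chunk_length + 2)` here):

```python
# Toy illustration of the no_boseos_middle=False branch above: weights are
# padded with 1.0 at each chunk's BOS/EOS slot, chunk by chunk, and the
# remainder is filled with 1.0 up to the total padded length.
def toy_pad_weights(weights, chunk_length, n_chunks):
    w = []
    for j in range((len(weights) - 1) // chunk_length + 1):
        w.append(1.0)                                    # chunk BOS slot
        w += weights[j * chunk_length : (j + 1) * chunk_length]
        w.append(1.0)                                    # chunk EOS slot
    w += [1.0] * (n_chunks * (chunk_length + 2) - len(w))
    return w

# 5 token weights, chunk capacity 3 -> two chunks of 5 positions each
print(toy_pad_weights([1.1] * 5, chunk_length=3, n_chunks=2))
# [1.0, 1.1, 1.1, 1.1, 1.0, 1.0, 1.1, 1.1, 1.0, 1.0]
```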
@@ -200,7 +193,7 @@ def get_unweighted_text_embeddings(
         text_embeddings = []
         for i in range(max_embeddings_multiples):
             # extract the i-th chunk
-            text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * (chunk_length - 2) + 2].clone()
+            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
 
             # cover the head and the tail by the starting and the ending tokens
             text_input_chunk[:, 0] = text_input[0, 0]
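The reformatted slice above determines the chunk boundaries: each chunk spans `chunk_length` positions and overlaps its neighbor by the two slots that are then overwritten with the BOS/EOS ids. Spelled out with the CLIP default of 77:

```python
# The slice above, spelled out: with chunk_length = 77 each chunk grabs 75
# "payload" positions plus 2 slots that get overwritten with BOS/EOS.
chunk_length = 77
for i in range(3):
    lo = i * (chunk_length - 2)
    hi = (i + 1) * (chunk_length - 2) + 2
    print(f"chunk {i}: tokens[{lo}:{hi}] -> {hi - lo} positions")
# chunk 0: tokens[0:77] -> 77 positions
# chunk 1: tokens[75:152] -> 77 positions
# chunk 2: tokens[150:227] -> 77 positions
```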
@@ -226,14 +219,14 @@ def get_unweighted_text_embeddings(
 
 
 def get_weighted_text_embeddings(
-        pipe: DiffusionPipeline,
-        prompt: Union[str, List[str]],
-        uncond_prompt: Optional[Union[str, List[str]]] = None,
-        max_embeddings_multiples: Optional[int] = 1,
-        no_boseos_middle: Optional[bool] = False,
-        skip_parsing: Optional[bool] = False,
-        skip_weighting: Optional[bool] = False,
-        **kwargs
+    pipe: DiffusionPipeline,
+    prompt: Union[str, List[str]],
+    uncond_prompt: Optional[Union[str, List[str]]] = None,
+    max_embeddings_multiples: Optional[int] = 1,
+    no_boseos_middle: Optional[bool] = False,
+    skip_parsing: Optional[bool] = False,
+    skip_weighting: Optional[bool] = False,
+    **kwargs,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -271,46 +264,64 @@ def get_weighted_text_embeddings(
                 uncond_prompt = [uncond_prompt]
             uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
     else:
-        prompt_tokens = [token[1:-1] for token in
-                         pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
-        prompt_weights = [[1.] * len(token) for token in prompt_tokens]
+        prompt_tokens = [
+            token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+        ]
+        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
         if uncond_prompt is not None:
             if isinstance(uncond_prompt, str):
                 uncond_prompt = [uncond_prompt]
-            uncond_tokens = [token[1:-1] for token in
-                             pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids]
-            uncond_weights = [[1.] * len(token) for token in uncond_tokens]
+            uncond_tokens = [
+                token[1:-1]
+                for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+            ]
+            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
 
     # round up the longest length of tokens to a multiple of (model_max_length - 2)
     max_length = max([len(token) for token in prompt_tokens])
     if uncond_prompt is not None:
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
-    max_embeddings_multiples = min(max_embeddings_multiples,
-                                   (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1)
+    max_embeddings_multiples = min(
+        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+    )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
 
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
-    prompt_tokens, prompt_weights = pad_tokens_and_weights(prompt_tokens, prompt_weights, max_length, bos, eos,
-                                                           no_boseos_middle=no_boseos_middle,
-                                                           chunk_length=pipe.tokenizer.model_max_length)
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=pipe.tokenizer.model_max_length,
+    )
     prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
     if uncond_prompt is not None:
-        uncond_tokens, uncond_weights = pad_tokens_and_weights(uncond_tokens, uncond_weights, max_length, bos, eos,
-                                                               no_boseos_middle=no_boseos_middle,
-                                                               chunk_length=pipe.tokenizer.model_max_length)
+        uncond_tokens, uncond_weights = pad_tokens_and_weights(
+            uncond_tokens,
+            uncond_weights,
+            max_length,
+            bos,
+            eos,
+            no_boseos_middle=no_boseos_middle,
+            chunk_length=pipe.tokenizer.model_max_length,
+        )
         uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
 
     # get the embeddings
-    text_embeddings = get_unweighted_text_embeddings(pipe, prompt_tokens, pipe.tokenizer.model_max_length,
-                                                     no_boseos_middle=no_boseos_middle)
+    text_embeddings = get_unweighted_text_embeddings(
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+    )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
-        uncond_embeddings = get_unweighted_text_embeddings(pipe, uncond_tokens, pipe.tokenizer.model_max_length,
-                                                           no_boseos_middle=no_boseos_middle)
+        uncond_embeddings = get_unweighted_text_embeddings(
+            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
     # assign weights to the prompts and normalize in the sense of mean
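The lines following this hunk (unchanged by the commit) apply the weights and then rescale so the embedding mean is preserved, which keeps heavily weighted prompts from drifting in overall magnitude. A minimal torch sketch of that "normalize in the sense of mean" step:

```python
import torch

# Mean-preserving reweighting sketch: scale per-token embeddings by their
# weights, then rescale so the mean over tokens and channels is unchanged.
emb = torch.randn(1, 77, 768)  # (batch, tokens, dim)
w = torch.ones(1, 77)
w[0, 5:8] = 1.5                # emphasize tokens 5..7

previous_mean = emb.mean(dim=(1, 2))
weighted = emb * w.unsqueeze(-1)
current_mean = weighted.mean(dim=(1, 2))
weighted = weighted * (previous_mean / current_mean).view(-1, 1, 1)

print(torch.allclose(weighted.mean(), emb.mean(), atol=1e-6))  # True
```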
@@ -382,14 +393,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
     """
 
     def __init__(
-            self,
-            vae: AutoencoderKL,
-            text_encoder: CLIPTextModel,
-            tokenizer: CLIPTokenizer,
-            unet: UNet2DConditionModel,
-            scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-            safety_checker: StableDiffusionSafetyChecker,
-            feature_extractor: CLIPFeatureExtractor,
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
 
@@ -456,26 +467,26 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-            self,
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-            mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-            height: int = 512,
-            width: int = 512,
-            num_inference_steps: int = 50,
-            guidance_scale: float = 7.5,
-            strength: float = 0.8,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: float = 0.0,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs,
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        strength: float = 0.8,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
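Only formatting changes in `__call__`, so behavior is unchanged. For review context, a hypothetical usage sketch (the repo id is a placeholder, and `custom_pipeline="lpw_stable_diffusion"` assumes the diffusers community pipeline this file corresponds to):

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical usage sketch; the repo id is a placeholder.
pipe = DiffusionPipeline.from_pretrained(
    "some-user/some-model",  # placeholder
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Brackets weight parts of the prompt; max_embeddings_multiples=3 accepts
# prompts up to three CLIP windows long instead of the usual single one.
image = pipe(
    prompt="masterpiece, (detailed eyes:1.2), a girl under a [cloudy] sky",
    negative_prompt="lowres, bad anatomy",
    max_embeddings_multiples=3,
).images[0]
image.save("out.png")
```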
@@ -563,7 +574,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         if (callback_steps is None) or (
-                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -588,12 +599,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 " the batch size of `prompt`."
             )
 
-        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
             prompt=prompt,
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs
+            **kwargs,
         )
         bs_embed, seq_len, _ = text_embeddings.shape
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -742,23 +753,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
     def text2img(
-            self,
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            height: int = 512,
-            width: int = 512,
-            num_inference_steps: int = 50,
-            guidance_scale: float = 7.5,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: float = 0.0,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs,
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
     ):
         r"""
         Function for text-to-image generation.
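`text2img` forwards its arguments to `__call__` with `init_image`/`mask_image` left unset, so the trailing-comma fix to `**kwargs` below is the only behavioral-adjacent edit. Continuing the hypothetical `pipe` from the earlier sketch:

```python
# text2img sketch, continuing the hypothetical pipe from above.
image = pipe.text2img(
    "a (cozy:1.2) cabin in the woods",
    width=512,
    height=512,
    num_inference_steps=30,
).images[0]
```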
@@ -830,26 +841,26 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             return_dict=return_dict,
             callback=callback,
             callback_steps=callback_steps,
-            **kwargs
+            **kwargs,
         )
 
     def img2img(
-            self,
-            init_image: Union[torch.FloatTensor, PIL.Image.Image],
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            strength: float = 0.8,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: Optional[float] = 0.0,
-            generator: Optional[torch.Generator] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs,
+        self,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
    ):
        r"""
        Function for image-to-image generation.
921
  return_dict=return_dict,
922
  callback=callback,
923
  callback_steps=callback_steps,
924
- **kwargs
925
  )
926
 
927
  def inpaint(
928
- self,
929
- init_image: Union[torch.FloatTensor, PIL.Image.Image],
930
- mask_image: Union[torch.FloatTensor, PIL.Image.Image],
931
- prompt: Union[str, List[str]],
932
- negative_prompt: Optional[Union[str, List[str]]] = None,
933
- strength: float = 0.8,
934
- num_inference_steps: Optional[int] = 50,
935
- guidance_scale: Optional[float] = 7.5,
936
- num_images_per_prompt: Optional[int] = 1,
937
- eta: Optional[float] = 0.0,
938
- generator: Optional[torch.Generator] = None,
939
- max_embeddings_multiples: Optional[int] = 3,
940
- output_type: Optional[str] = "pil",
941
- return_dict: bool = True,
942
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
943
- callback_steps: Optional[int] = 1,
944
- **kwargs,
945
  ):
946
  r"""
947
  Function for inpaint.
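And an inpaint usage sketch (paths are placeholders; in the usual convention for this legacy-style inpainting, white mask pixels are regenerated and black pixels are kept):

```python
# Inpaint sketch, continuing the hypothetical pipe; paths are placeholders.
init = Image.open("photo.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("L").resize((512, 512))
image = pipe.inpaint(
    init_image=init,
    mask_image=mask,
    prompt="a (red:1.3) scarf",
    strength=0.75,
).images[0]
```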
@@ -1018,5 +1029,5 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             return_dict=return_dict,
             callback=callback,
             callback_steps=callback_steps,
-            **kwargs
+            **kwargs,
         )
 