Update pipeline.py
pipeline.py (changed): +149 additions, -138 deletions
@@ -1,10 +1,9 @@
 import inspect
 import re
-from typing import Callable, List, Optional, Union
 import PIL
 import numpy as np
 import torch
-
+from typing import Callable, List, Optional, Union
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from diffusers.configuration_utils import FrozenDict
@@ -17,7 +16,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-re_attention = re.compile(r"""
+re_attention = re.compile(
+    r"""
 \\\(|
 \\\)|
 \\\[|
@@ -31,7 +31,9 @@ re_attention = re.compile(r"""
 ]|
 [^\\()\[\]:]+|
 :
-""", re.X)
+""",
+    re.X,
+)
 
 
 def parse_prompt_attention(text):
@@ -84,17 +86,17 @@ def parse_prompt_attention(text):
         text = m.group(0)
         weight = m.group(1)
 
-        if text.startswith('\\'):
+        if text.startswith("\\"):
             res.append([text[1:], 1.0])
-        elif text == '(':
+        elif text == "(":
             round_brackets.append(len(res))
-        elif text == '[':
+        elif text == "[":
             square_brackets.append(len(res))
         elif weight is not None and len(round_brackets) > 0:
             multiply_range(round_brackets.pop(), float(weight))
-        elif text == ')' and len(round_brackets) > 0:
+        elif text == ")" and len(round_brackets) > 0:
             multiply_range(round_brackets.pop(), round_bracket_multiplier)
-        elif text == ']' and len(square_brackets) > 0:
+        elif text == "]" and len(square_brackets) > 0:
             multiply_range(square_brackets.pop(), square_bracket_multiplier)
         else:
             res.append([text, 1.0])
@@ -120,11 +122,7 @@ def parse_prompt_attention(text):
     return res
 
 
-def get_prompts_with_weights(
-        pipe: DiffusionPipeline,
-        prompt: List[str],
-        max_length: int
-):
+def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.
 
@@ -158,9 +156,7 @@ def get_prompts_with_weights(
     return tokens, weights
 
 
-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
-                           no_boseos_middle=True,
-                           chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
@@ -169,27 +165,24 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos,
     for i in range(len(tokens)):
         tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
         if no_boseos_middle:
-            weights[i] = [1.] + weights[i] + [1.] * (max_length - 1 - len(weights[i]))
+            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
             w = []
             if len(weights[i]) == 0:
-                w = [1.] * weights_length
+                w = [1.0] * weights_length
             else:
                 for j in range((len(weights[i]) - 1) // chunk_length + 1):
-                    w.append(1.)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length: min(len(weights[i]), (j + 1) * chunk_length)]
-                    w.append(1.)  # weight for ending token in this chunk
-                w += [1.] * (weights_length - len(w))
+                    w.append(1.0)  # weight for starting token in this chunk
+                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w.append(1.0)  # weight for ending token in this chunk
+                w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
 
     return tokens, weights
 
 
 def get_unweighted_text_embeddings(
-        pipe: DiffusionPipeline,
-        text_input: torch.Tensor,
-        chunk_length: int,
-        no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
@@ -200,7 +193,7 @@ def get_unweighted_text_embeddings(
     text_embeddings = []
     for i in range(max_embeddings_multiples):
         # extract the i-th chunk
-        text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * (chunk_length - 2) + 2].clone()
+        text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
 
         # cover the head and the tail by the starting and the ending tokens
         text_input_chunk[:, 0] = text_input[0, 0]
@@ -226,14 +219,14 @@ def get_unweighted_text_embeddings(
 
 
 def get_weighted_text_embeddings(
-        pipe: DiffusionPipeline,
-        prompt: Union[str, List[str]],
-        uncond_prompt: Optional[Union[str, List[str]]] = None,
-        max_embeddings_multiples: Optional[int] = 1,
-        no_boseos_middle: Optional[bool] = False,
-        skip_parsing: Optional[bool] = False,
-        skip_weighting: Optional[bool] = False,
-        **kwargs
+    pipe: DiffusionPipeline,
+    prompt: Union[str, List[str]],
+    uncond_prompt: Optional[Union[str, List[str]]] = None,
+    max_embeddings_multiples: Optional[int] = 1,
+    no_boseos_middle: Optional[bool] = False,
+    skip_parsing: Optional[bool] = False,
+    skip_weighting: Optional[bool] = False,
+    **kwargs,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -271,46 +264,64 @@ def get_weighted_text_embeddings(
                 uncond_prompt = [uncond_prompt]
             uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
     else:
-        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length,
-                                                                 truncation=True).input_ids]
-        prompt_weights = [[1.] * len(token) for token in prompt_tokens]
+        prompt_tokens = [
+            token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+        ]
+        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
         if uncond_prompt is not None:
             if isinstance(uncond_prompt, str):
                 uncond_prompt = [uncond_prompt]
-            uncond_tokens = [token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length,
-                                                                     truncation=True).input_ids]
-            uncond_weights = [[1.] * len(token) for token in uncond_tokens]
+            uncond_tokens = [
+                token[1:-1]
+                for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+            ]
+            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
 
     # round up the longest length of tokens to a multiple of (model_max_length - 2)
     max_length = max([len(token) for token in prompt_tokens])
     if uncond_prompt is not None:
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
-    max_embeddings_multiples = min(max_embeddings_multiples,
-                                   (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1)
+    max_embeddings_multiples = min(
+        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+    )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
 
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
-    prompt_tokens, prompt_weights = pad_tokens_and_weights(prompt_tokens, prompt_weights, max_length,
-                                                           bos, eos, no_boseos_middle=no_boseos_middle,
-                                                           chunk_length=pipe.tokenizer.model_max_length)
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=pipe.tokenizer.model_max_length,
+    )
     prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
     if uncond_prompt is not None:
-        uncond_tokens, uncond_weights = pad_tokens_and_weights(uncond_tokens, uncond_weights, max_length,
-                                                               bos, eos, no_boseos_middle=no_boseos_middle,
-                                                               chunk_length=pipe.tokenizer.model_max_length)
+        uncond_tokens, uncond_weights = pad_tokens_and_weights(
+            uncond_tokens,
+            uncond_weights,
+            max_length,
+            bos,
+            eos,
+            no_boseos_middle=no_boseos_middle,
+            chunk_length=pipe.tokenizer.model_max_length,
+        )
         uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
 
     # get the embeddings
-    text_embeddings = get_unweighted_text_embeddings(pipe, prompt_tokens, pipe.tokenizer.model_max_length,
-                                                     no_boseos_middle=no_boseos_middle)
+    text_embeddings = get_unweighted_text_embeddings(
+        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
-        uncond_embeddings = get_unweighted_text_embeddings(pipe, uncond_tokens, pipe.tokenizer.model_max_length,
-                                                           no_boseos_middle=no_boseos_middle)
+        uncond_embeddings = get_unweighted_text_embeddings(
+            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
     # assign weights to the prompts and normalize in the sense of mean
@@ -382,14 +393,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
     """
 
     def __init__(
-            self,
-            vae: AutoencoderKL,
-            text_encoder: CLIPTextModel,
-            tokenizer: CLIPTokenizer,
-            unet: UNet2DConditionModel,
-            scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-            safety_checker: StableDiffusionSafetyChecker,
-            feature_extractor: CLIPFeatureExtractor,
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
 
@@ -456,26 +467,26 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-            self,
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-            mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-            height: int = 512,
-            width: int = 512,
-            num_inference_steps: int = 50,
-            guidance_scale: float = 7.5,
-            strength: float = 0.8,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: float = 0.0,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        strength: float = 0.8,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -563,7 +574,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
        if (callback_steps is None) or (
-                callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -588,12 +599,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                    " the batch size of `prompt`."
                )
 
        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
-            **kwargs
+            **kwargs,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -742,23 +753,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
    def text2img(
-            self,
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            height: int = 512,
-            width: int = 512,
-            num_inference_steps: int = 50,
-            guidance_scale: float = 7.5,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: float = 0.0,
-            generator: Optional[torch.Generator] = None,
-            latents: Optional[torch.FloatTensor] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
    ):
        r"""
        Function for text-to-image generation.
@@ -830,26 +841,26 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
-            **kwargs
+            **kwargs,
        )
 
    def img2img(
-            self,
-            init_image: Union[torch.FloatTensor, PIL.Image.Image],
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            strength: float = 0.8,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: Optional[float] = 0.0,
-            generator: Optional[torch.Generator] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs
+        self,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
    ):
        r"""
        Function for image-to-image generation.
@@ -921,27 +932,27 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
-            **kwargs
+            **kwargs,
        )
 
    def inpaint(
-            self,
-            init_image: Union[torch.FloatTensor, PIL.Image.Image],
-            mask_image: Union[torch.FloatTensor, PIL.Image.Image],
-            prompt: Union[str, List[str]],
-            negative_prompt: Optional[Union[str, List[str]]] = None,
-            strength: float = 0.8,
-            num_inference_steps: Optional[int] = 50,
-            guidance_scale: Optional[float] = 7.5,
-            num_images_per_prompt: Optional[int] = 1,
-            eta: Optional[float] = 0.0,
-            generator: Optional[torch.Generator] = None,
-            max_embeddings_multiples: Optional[int] = 3,
-            output_type: Optional[str] = "pil",
-            return_dict: bool = True,
-            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-            callback_steps: Optional[int] = 1,
-            **kwargs
+        self,
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        max_embeddings_multiples: Optional[int] = 3,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
    ):
        r"""
        Function for inpaint.
@@ -1018,5 +1029,5 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
-            **kwargs
+            **kwargs,
        )