Add TextualInversion to processed prompts. Other updates.
pipeline.py  +18 -9
CHANGED
@@ -6,13 +6,14 @@ import numpy as np
 import PIL
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 import random
 import sys
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import SchedulerMixin, StableDiffusionPipeline
+from diffusers.loaders import TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
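Context for the changed import: transformers deprecated the CLIPFeatureExtractor name in favor of CLIPImageProcessor, which is why the type annotations and docstring below change as well. A hedged compatibility sketch for code that still has to run on transformers releases predating the rename (the fallback alias is an assumption, not part of this commit):

    # Prefer the new name; alias the old one on older transformers.
    try:
        from transformers import CLIPImageProcessor
    except ImportError:  # releases before the CLIPImageProcessor rename
        from transformers import CLIPFeatureExtractor as CLIPImageProcessor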
@@ -182,14 +183,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     return tokens, weights
 
 
-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
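What the padding change does: the old code filled the tail of each sequence entirely with eos; the new code fills it with the tokenizer's pad token and keeps a single closing eos. A toy sketch with made-up token ids (values are illustrative only, not from the source):

    bos, eos, pad = 49406, 49407, 0  # hypothetical ids
    tokens, max_length = [100, 200, 300], 8

    old = [bos] + tokens + [eos] * (max_length - 1 - len(tokens))
    # -> [49406, 100, 200, 300, 49407, 49407, 49407, 49407]

    new = [bos] + tokens + [pad] * (max_length - 1 - len(tokens) - 1) + [eos]
    # -> [49406, 100, 200, 300, 0, 0, 0, 49407]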
@@ -320,12 +321,14 @@ def get_weighted_text_embeddings(
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
+    pad = getattr(pipe.tokenizer, "pad_token_id", eos)
     prompt_tokens, prompt_weights = pad_tokens_and_weights(
         prompt_tokens,
         prompt_weights,
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
@@ -337,6 +340,7 @@ def get_weighted_text_embeddings(
             max_length,
             bos,
             eos,
+            pad,
             no_boseos_middle=no_boseos_middle,
             chunk_length=pipe.tokenizer.model_max_length,
         )
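The two hunks above thread the new pad id through both the prompt and the negative-prompt calls. The getattr fallback preserves the old behavior for tokenizers that don't expose a pad token: pad simply falls back to eos. The distinction matters for checkpoints whose pad token differs from eos; a small sketch (the model id is an assumption for illustration):

    from transformers import CLIPTokenizer

    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    eos = tok.eos_token_id
    pad = getattr(tok, "pad_token_id", eos)  # eos only if pad_token_id is missing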
@@ -379,7 +383,7 @@ def get_weighted_text_embeddings(
 
 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -390,7 +394,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
     mask = np.tile(mask, (4, 1, 1))
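Both preprocessors still snap width and height down to the nearest multiple of 32, as the UNet/VAE downsampling stages require; only the idiom changed from map(lambda ...) to a generator expression, which tuple unpacking consumes the same way. The rounding in isolation:

    w, h = 513, 769
    w, h = (x - x % 32 for x in (w, h))
    assert (w, h) == (512, 768)  # exact multiples pass through unchanged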
@@ -425,7 +429,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
@@ -439,7 +443,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__(
@@ -464,7 +468,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__(
             vae=vae,
@@ -538,6 +542,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                     " the batch size of `prompt`."
                 )
+
+        # textual inversion: process multi-vector tokens if necessary
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+            negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
 
         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
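This is the commit's headline change: before weighting, prompts now pass through maybe_convert_prompt, which expands a multi-vector textual-inversion trigger word into its underlying placeholder tokens so that long-prompt weighting sees every vector. A hedged usage sketch (the embedding repo and trigger token are examples, and loading via the diffusers community-pipeline mechanism is one common way to consume this file):

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", custom_pipeline="lpw_stable_diffusion"
    )
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")
    # "<cat-toy>" is expanded by maybe_convert_prompt if the embedding
    # spans multiple vectors, then weighted like any other token run.
    image = pipe("a <cat-toy> on a bench, (highly detailed:1.2)").images[0]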
@@ -627,7 +636,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
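Reading in_channels via unet.config rather than directly off the module avoids the deprecation path diffusers uses for direct attribute access on model configs; the value itself is unchanged. For illustration (assuming a standard Stable Diffusion UNet):

    channels = pipe.unet.config.in_channels  # typically 4 for SD latents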