Update pipeline.py
Browse files- pipeline.py +26 -27
pipeline.py
CHANGED
@@ -15,7 +15,6 @@ from diffusers.utils import deprecate, logging
|
|
15 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
16 |
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
17 |
|
18 |
-
|
19 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
20 |
|
21 |
re_attention = re.compile(r"""
|
@@ -387,14 +386,14 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
|
|
387 |
"""
|
388 |
|
389 |
def __init__(
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
):
|
399 |
super().__init__()
|
400 |
|
@@ -461,23 +460,23 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
|
|
461 |
|
462 |
@torch.no_grad()
|
463 |
def text2img(
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
):
|
482 |
r"""
|
483 |
Function invoked when calling the pipeline for generation.
|
@@ -547,7 +546,7 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
|
|
547 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
548 |
|
549 |
if (callback_steps is None) or (
|
550 |
-
|
551 |
):
|
552 |
raise ValueError(
|
553 |
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
|
|
15 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
16 |
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
17 |
|
|
|
18 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
19 |
|
20 |
re_attention = re.compile(r"""
|
|
|
386 |
"""
|
387 |
|
388 |
def __init__(
|
389 |
+
self,
|
390 |
+
vae: AutoencoderKL,
|
391 |
+
text_encoder: CLIPTextModel,
|
392 |
+
tokenizer: CLIPTokenizer,
|
393 |
+
unet: UNet2DConditionModel,
|
394 |
+
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
395 |
+
safety_checker: StableDiffusionSafetyChecker,
|
396 |
+
feature_extractor: CLIPFeatureExtractor,
|
397 |
):
|
398 |
super().__init__()
|
399 |
|
|
|
460 |
|
461 |
@torch.no_grad()
|
462 |
def text2img(
|
463 |
+
self,
|
464 |
+
prompt: Union[str, List[str]],
|
465 |
+
height: int = 512,
|
466 |
+
width: int = 512,
|
467 |
+
num_inference_steps: int = 50,
|
468 |
+
guidance_scale: float = 7.5,
|
469 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
470 |
+
num_images_per_prompt: Optional[int] = 1,
|
471 |
+
eta: float = 0.0,
|
472 |
+
generator: Optional[torch.Generator] = None,
|
473 |
+
latents: Optional[torch.FloatTensor] = None,
|
474 |
+
max_embeddings_multiples: Optional[int] = 3,
|
475 |
+
output_type: Optional[str] = "pil",
|
476 |
+
return_dict: bool = True,
|
477 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
478 |
+
callback_steps: Optional[int] = 1,
|
479 |
+
**kwargs,
|
480 |
):
|
481 |
r"""
|
482 |
Function invoked when calling the pipeline for generation.
|
|
|
546 |
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
547 |
|
548 |
if (callback_steps is None) or (
|
549 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
550 |
):
|
551 |
raise ValueError(
|
552 |
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|