skytnt commited on
Commit
eade452
1 Parent(s): fca2f57

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -27
pipeline.py CHANGED
@@ -15,7 +15,6 @@ from diffusers.utils import deprecate, logging
15
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
16
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
17
 
18
-
19
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
 
21
  re_attention = re.compile(r"""
@@ -387,14 +386,14 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
387
  """
388
 
389
  def __init__(
390
- self,
391
- vae: AutoencoderKL,
392
- text_encoder: CLIPTextModel,
393
- tokenizer: CLIPTokenizer,
394
- unet: UNet2DConditionModel,
395
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
396
- safety_checker: StableDiffusionSafetyChecker,
397
- feature_extractor: CLIPFeatureExtractor,
398
  ):
399
  super().__init__()
400
 
@@ -461,23 +460,23 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
461
 
462
  @torch.no_grad()
463
  def text2img(
464
- self,
465
- prompt: Union[str, List[str]],
466
- height: int = 512,
467
- width: int = 512,
468
- num_inference_steps: int = 50,
469
- guidance_scale: float = 7.5,
470
- negative_prompt: Optional[Union[str, List[str]]] = None,
471
- num_images_per_prompt: Optional[int] = 1,
472
- eta: float = 0.0,
473
- generator: Optional[torch.Generator] = None,
474
- latents: Optional[torch.FloatTensor] = None,
475
- max_embeddings_multiples: Optional[int] = 3,
476
- output_type: Optional[str] = "pil",
477
- return_dict: bool = True,
478
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
479
- callback_steps: Optional[int] = 1,
480
- **kwargs,
481
  ):
482
  r"""
483
  Function invoked when calling the pipeline for generation.
@@ -547,7 +546,7 @@ class StableDiffusionLongPromptPipeline(DiffusionPipeline):
547
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
548
 
549
  if (callback_steps is None) or (
550
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
551
  ):
552
  raise ValueError(
553
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
 
15
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
16
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
17
 
 
18
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
 
20
  re_attention = re.compile(r"""
 
386
  """
387
 
388
  def __init__(
389
+ self,
390
+ vae: AutoencoderKL,
391
+ text_encoder: CLIPTextModel,
392
+ tokenizer: CLIPTokenizer,
393
+ unet: UNet2DConditionModel,
394
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
395
+ safety_checker: StableDiffusionSafetyChecker,
396
+ feature_extractor: CLIPFeatureExtractor,
397
  ):
398
  super().__init__()
399
 
 
460
 
461
  @torch.no_grad()
462
  def text2img(
463
+ self,
464
+ prompt: Union[str, List[str]],
465
+ height: int = 512,
466
+ width: int = 512,
467
+ num_inference_steps: int = 50,
468
+ guidance_scale: float = 7.5,
469
+ negative_prompt: Optional[Union[str, List[str]]] = None,
470
+ num_images_per_prompt: Optional[int] = 1,
471
+ eta: float = 0.0,
472
+ generator: Optional[torch.Generator] = None,
473
+ latents: Optional[torch.FloatTensor] = None,
474
+ max_embeddings_multiples: Optional[int] = 3,
475
+ output_type: Optional[str] = "pil",
476
+ return_dict: bool = True,
477
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
478
+ callback_steps: Optional[int] = 1,
479
+ **kwargs,
480
  ):
481
  r"""
482
  Function invoked when calling the pipeline for generation.
 
546
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
547
 
548
  if (callback_steps is None) or (
549
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
550
  ):
551
  raise ValueError(
552
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"