from typing import List, Optional, Union

import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    retrieve_timesteps,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)


def freeze_params(params):
    """Disable gradient tracking on every parameter in ``params``.

    Args:
        params: Iterable of parameters, e.g. ``module.parameters()``.
    """
    for param in params:
        param.requires_grad = False


class RewardStableDiffusionXL(StableDiffusionXLPipeline):
    """SDXL pipeline whose :meth:`apply` produces a differentiable image tensor.

    The VAE, UNet and both text encoders are frozen (``requires_grad = False``)
    and put in eval mode, and gradient checkpointing is enabled on the UNet and
    VAE. :meth:`apply` runs the sampling loop without any gradient-disabling
    context, so gradients can flow back to inputs such as the starting
    ``latents`` while the model weights stay fixed.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: bool = False,
        is_hyper: bool = False,
        memsave: bool = False,
    ):
        """Build the pipeline, then freeze and prepare all sub-models.

        Args:
            vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet,
            scheduler, image_encoder, feature_extractor,
            force_zeros_for_empty_prompt, add_watermarker:
                Forwarded positionally to ``StableDiffusionXLPipeline``.
            is_hyper: If True, :meth:`apply` forces the single timestep
                ``[800]`` and decodes latents in float32, returning float16.
            memsave: If True, wrap the VAE, UNet and both text encoders with
                ``memsave_torch.nn.convert_to_memory_saving`` (requires the
                optional ``memsave_torch`` package).
        """
        super().__init__(
            vae,
            text_encoder,
            text_encoder_2,
            tokenizer,
            tokenizer_2,
            unet,
            scheduler,
            image_encoder,
            feature_extractor,
            force_zeros_for_empty_prompt,
            add_watermarker,
        )

        # Optionally enable memsave_torch; imported lazily so the dependency
        # is only required when the feature is requested.
        if memsave:
            import memsave_torch.nn

            self.vae = memsave_torch.nn.convert_to_memory_saving(self.vae)
            self.unet = memsave_torch.nn.convert_to_memory_saving(self.unet)
            self.text_encoder = memsave_torch.nn.convert_to_memory_saving(
                self.text_encoder
            )
            self.text_encoder_2 = memsave_torch.nn.convert_to_memory_saving(
                self.text_encoder_2
            )

        # Trade recomputation for activation memory when backpropagating
        # through the UNet and VAE.
        self.unet.enable_gradient_checkpointing()
        self.vae.enable_gradient_checkpointing()

        self.text_encoder.eval()
        self.text_encoder_2.eval()
        self.unet.eval()
        self.vae.eval()

        self.is_hyper = is_hyper

        # Freeze diffusion parameters: gradients are only wanted w.r.t. the
        # inputs passed to `apply`, never w.r.t. model weights.
        freeze_params(self.vae.parameters())
        freeze_params(self.unet.parameters())
        freeze_params(self.text_encoder.parameters())
        freeze_params(self.text_encoder_2.parameters())

    def decode_latents_tensors(self, latents):
        """Decode latents with the VAE into an image tensor clamped to [0, 1].

        Args:
            latents: Latent tensor as produced by the denoising loop.

        Returns:
            The VAE ``sample`` output, mapped from [-1, 1] to [0, 1] and
            clamped. Gradients are kept (no ``no_grad`` here), but the clamp
            zeroes them for values outside [0, 1].
        """
        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

    def apply(
        self,
        latents: torch.Tensor,
        prompt: Optional[Union[str, List[str]]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 1,
        guidance_scale: float = 0.0,
        timesteps: Optional[List[int]] = None,
        denoising_end: Optional[float] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ) -> torch.Tensor:
        """Run the SDXL sampling loop from ``latents`` and return the decoded image.

        No gradient-disabling context is used, so the returned image stays
        connected to the autograd graph of ``latents``.

        Args:
            latents: Starting latents, forwarded to ``prepare_latents``. Their
                batch dimension presumably must equal
                ``len(prompts) * num_images_per_prompt`` (TODO confirm).
            prompt, prompt_2: Prompt(s) for the two text encoders. A list of
                N prompts produces N conditioning entries per image.
            height, width: Output size in pixels; defaults to
                ``default_sample_size * vae_scale_factor``.
            num_inference_steps: Number of scheduler steps. Recomputed by
                ``retrieve_timesteps`` when ``timesteps`` is given.
            guidance_scale: Classifier-free guidance weight. Whether CFG is
                active is decided by the ``do_classifier_free_guidance``
                property.
            timesteps: Explicit timesteps. Ignored and replaced by ``[800]``
                when ``self.is_hyper`` is True.
            denoising_end: Optional fraction in (0, 1). Timesteps below the
                corresponding cutoff are dropped.
            negative_prompt, negative_prompt_2: Negative prompts used when
                CFG is active.
            num_images_per_prompt: Images generated per prompt.
            eta: Passed to the scheduler via ``prepare_extra_step_kwargs``.
            generator: RNG(s) forwarded to ``prepare_latents`` and the
                scheduler step kwargs.

        Returns:
            Image tensor in [0, 1] from :meth:`decode_latents_tensors`, cast to
            float16 when ``self.is_hyper`` is True.
        """
        if self.is_hyper:
            # Hyper mode always samples at this single fixed timestep,
            # overriding any caller-supplied `timesteps`.
            timesteps = [800]

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        original_size = (height, width)
        target_size = (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps=1,
        )

        # 2. Define call parameters
        self._guidance_scale = guidance_scale
        self._clip_skip = 0
        self._cross_attention_kwargs = None
        self._denoising_end = denoising_end
        self._interrupt = False

        # Number of prompts. It is used to size the time ids and guidance
        # embedding; these presumably must line up with the per-prompt
        # embeddings returned by encode_prompt (a hard-coded 1 mismatches
        # them for list prompts).
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )

        # 5. Prepare latent variables. The caller's `latents` is passed
        # through, so presumably it is used as the start point instead of
        # freshly sampled noise.
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            (0, 0),  # crop top-left coordinates
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            # Unconditional entries first; the guidance step below relies on
            # this order when chunking the UNet output.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1
        )

        # 8. Apply denoising_end: keep only the leading timesteps at or above
        # the cutoff.
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and 0 < self.denoising_end < 1
        ):
            num_train_timesteps = self.scheduler.config.num_train_timesteps
            discrete_timestep_cutoff = int(
                round(num_train_timesteps - self.denoising_end * num_train_timesteps)
            )
            num_inference_steps = sum(
                1 for ts in timesteps if ts >= discrete_timestep_cutoff
            )
            timesteps = timesteps[:num_inference_steps]

        # 9. Optionally get Guidance Scale Embedding (only for UNets that
        # declare `time_cond_proj_dim`). This value must reach the loop
        # below; it must not be reset afterwards.
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)

        # 10. Denoising loop
        for t in timesteps:
            if self._interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {
                "text_embeds": add_text_embeds,
                "time_ids": add_time_ids,
            }
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
            )[0]

        if self.is_hyper:
            # NOTE(review): decodes in float32 and returns float16; this
            # presumably assumes a float32 VAE in hyper mode — confirm.
            latents = latents.to(torch.float32)
            image = self.decode_latents_tensors(latents)
            image = image.to(torch.float16)
        else:
            image = self.decode_latents_tensors(latents)

        # apply watermark if available
        # NOTE(review): the watermarker may round-trip through non-torch
        # arrays and break the autograd graph — verify before enabling
        # add_watermarker for gradient-based use.
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        # Offload all models
        self.maybe_free_model_hooks()

        return image