import torch

from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from mv_unet import SPADUnetModel


class SPADPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        unet: SPADUnetModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DPMSolverMultistepScheduler,
    ):
        super().__init__()

        # register_modules() lets DiffusionPipeline track the components, so a
        # single `pipeline.to(device)` later moves all of them at once. Calling
        # `.to(self.device)` inside __init__ is a no-op, since the pipeline
        # starts out on the CPU.
        self.register_modules(
            vae=vae,
            unet=unet,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=scheduler,
        )

    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None):
        if isinstance(prompt, str):
            prompt = [prompt]

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]

        # Duplicate the embeddings once per requested image, keeping the
        # per-prompt ordering intact.
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # For classifier-free guidance, encode the negative (or empty) prompt
        # and prepend it so both branches run in one batched forward pass.
        if do_classifier_free_guidance:
            uncond_tokens = [""] * len(prompt) if negative_prompt is None else negative_prompt
            uncond_input = self.tokenizer(
                uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length,
                truncation=True, return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, negative_prompt=None):
        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.encode_prompt(
            prompt, self.device, 1, do_classifier_free_guidance, negative_prompt=negative_prompt
        )

        # Start from Gaussian noise in latent space, scaled to the scheduler's
        # initial sigma. The batch size is the number of prompts, not the
        # embedding batch, which is doubled under classifier-free guidance.
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        latents = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            device=self.device,
            dtype=text_embeddings.dtype,
        )
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        for t in self.scheduler.timesteps:
            # Run the unconditional and conditional branches as one batch,
            # keeping the unscaled latents for the scheduler step below.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(latent_model_input, t, text_embeddings)["sample"]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # scheduler.step(model_output, timestep, sample) returns the
            # denoised latents for the next timestep.
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode to pixel space: Stable-Diffusion-style VAEs expect latents
        # divided by the scaling factor, and output images in [-1, 1].
        images = self.vae.decode(latents / self.vae.config.scaling_factor).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        return images
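

# A minimal usage sketch. The checkpoint id and the SPADUnetModel constructor
# below are illustrative assumptions, not defined in this file:
#
#   repo = "runwayml/stable-diffusion-v1-5"  # hypothetical base checkpoint
#   vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")
#   tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
#   text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
#   scheduler = DPMSolverMultistepScheduler.from_pretrained(repo, subfolder="scheduler")
#   unet = SPADUnetModel(...)  # constructor args depend on mv_unet
#
#   pipe = SPADPipeline(vae, unet, tokenizer, text_encoder, scheduler).to("cuda")
#   images = pipe("a wooden chair", num_inference_steps=30, guidance_scale=7.5)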