|
import torch |
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
class MyPipeline(DiffusionPipeline):
    """Minimal unconditional image-generation pipeline.

    Starts from Gaussian noise and repeatedly applies the registered
    ``unet`` and ``scheduler`` to denoise it into a batch of images.
    """

    def __init__(self, unet, scheduler):
        """Register the denoising model and the scheduler on the pipeline.

        Args:
            unet: Model called as ``unet(image, t)``; its result must expose
                a ``.sample`` attribute, and its ``config`` must provide
                ``in_channels`` and ``sample_size``.
            scheduler: Object providing ``set_timesteps``, ``timesteps`` and
                ``step(model_output, t, image)`` (whose result exposes
                ``.prev_sample``).
        """
        super().__init__()

        # register_modules makes the components available as self.unet and
        # self.scheduler.
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
        """Run the denoising loop and return the generated images.

        Args:
            batch_size: Number of images to generate.
            num_inference_steps: Number of timesteps handed to
                ``scheduler.set_timesteps``.

        Returns:
            A NumPy array laid out as ``(batch, height, width, channels)``
            with values clamped to ``[0, 1]``.
        """
        # Read model hyper-parameters through ``config``: accessing them
        # directly on the model (``self.unet.in_channels``) is deprecated in
        # Diffusers and emits a FutureWarning.
        unet_config = self.unet.config
        channels = unet_config.in_channels
        size = unet_config.sample_size

        # Sample the initial Gaussian noise, then move it to the pipeline's
        # device.
        image = torch.randn((batch_size, channels, size, size))
        image = image.to(self.device)

        # Configure the scheduler's timestep sequence for this run.
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. Run the model on the current sample at timestep t.
            model_output = self.unet(image, t).sample

            # 2. Let the scheduler compute the previous (less noisy) sample
            #    from the model output, the timestep and the current sample.
            image = self.scheduler.step(model_output, t, image).prev_sample

        # Rescale (presumably from [-1, 1]; depends on the model) into [0, 1]
        # and clamp anything that falls outside.
        image = (image / 2 + 0.5).clamp(0, 1)
        # Channels-first -> channels-last, on CPU, as a NumPy array.
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return image
|
|