ahmedfaiyaz commited on
Commit
5b5781c
1 Parent(s): 5f857dc

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +91 -0
pipeline.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ from diffusers.utils.torch_utils import randn_tensor
4
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
5
+
6
+
7
+ class OkkhorDiffusionPipeline(DiffusionPipeline):
8
+ r"""
9
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
10
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
11
+
12
+ Parameters:
13
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
14
+ scheduler ([`SchedulerMixin`]):
15
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
16
+ [`DDPMScheduler`], or [`DDIMScheduler`].
17
+ """
18
+
19
+ def __init__(self, unet, scheduler,embedding):
20
+ super().__init__()
21
+ self.register_modules(unet=unet, scheduler=scheduler,embedding = embedding)
22
+
23
+
24
+ @torch.no_grad()
25
+ def __call__(
26
+ self,
27
+ batch_size: int = 1,
28
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
29
+ num_inference_steps: int = 1000,
30
+ output_type: Optional[str] = "pil",
31
+ return_dict: bool = True,
32
+ ) -> Union[ImagePipelineOutput, Tuple]:
33
+ r"""
34
+ Args:
35
+ batch_size (`int`, *optional*, defaults to 1):
36
+ The number of images to generate.
37
+ generator (`torch.Generator`, *optional*):
38
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
39
+ to make generation deterministic.
40
+ num_inference_steps (`int`, *optional*, defaults to 1000):
41
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
42
+ expense of slower inference.
43
+ output_type (`str`, *optional*, defaults to `"pil"`):
44
+ The output format of the generate image. Choose between
45
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
46
+ return_dict (`bool`, *optional*, defaults to `True`):
47
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
48
+
49
+ Returns:
50
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
51
+ True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
52
+ """
53
+ # Sample gaussian noise to begin loop
54
+ if isinstance(self.unet.config.sample_size, int):
55
+ image_shape = (
56
+ batch_size,
57
+ self.unet.config.in_channels,
58
+ self.unet.config.sample_size,
59
+ self.unet.config.sample_size,
60
+ )
61
+ else:
62
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
63
+
64
+ if self.device.type == "mps":
65
+ # randn does not work reproducibly on mps
66
+ image = randn_tensor(image_shape, generator=generator)
67
+ image = image.to(self.device)
68
+ else:
69
+ image = randn_tensor(image_shape, generator=generator, device=self.device)
70
+ if self.embedding:
71
+ self.embedding=self.embedding.to(self.device)
72
+
73
+ # set step values
74
+ self.scheduler.set_timesteps(num_inference_steps)
75
+
76
+ for t in self.progress_bar(self.scheduler.timesteps):
77
+ # 1. predict noise model_output
78
+ model_output = self.unet(image, t,class_labels=self.embedding).sample
79
+
80
+ # 2. compute previous image: x_t -> x_t-1
81
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
82
+
83
+ image = (image / 2 + 0.5).clamp(0, 1)
84
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
85
+ if output_type == "pil":
86
+ image = self.numpy_to_pil(image)
87
+
88
+ if not return_dict:
89
+ return (image,)
90
+
91
+ return ImagePipelineOutput(images=image)