KingNish committed
Commit 806b606
1 Parent(s): 4a3ad2f

Update custom_pipeline.py

Files changed (1)
  1. custom_pipeline.py +148 -50
custom_pipeline.py CHANGED
@@ -1,8 +1,16 @@
-import torch
 import numpy as np
-from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
-from typing import Any, Dict, List, Optional, Union
-from PIL import Image
+import torch
+from diffusers.pipelines.flux.pipeline_output import FluxPipeline, FluxPipelineOutput
+from typing import List, Union, Optional, Dict, Any, Callable
+from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from diffusers.utils import is_torch_xla_available
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 # Constants for shift calculation
 BASE_SEQ_LEN = 256
@@ -19,7 +27,7 @@ def calculate_timestep_shift(image_seq_len: int) -> float:
     return mu
 
 def prepare_timesteps(
-    scheduler: FlowMatchEulerDiscreteScheduler,
+    scheduler,
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
@@ -41,20 +49,23 @@ def prepare_timesteps(
     num_inference_steps = len(timesteps)
     return timesteps, num_inference_steps
 
-# FLUX pipeline function
-class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
+# FLUX pipeline with CFG and intermediate outputs
+class FluxWithCFGPipeline(FluxPipeline):
     """
-    Extends the FluxPipeline to yield intermediate images during the denoising process
-    with progressively increasing resolution for faster generation.
+    Flux pipeline with Classifier-Free Guidance and the ability to yield
+    intermediate images during the denoising process with progressively
+    increasing resolution for faster generation.
     """
     @torch.inference_mode()
-    def generate_images(
+    def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: int = 4,
+        num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
@@ -62,16 +73,21 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        max_sequence_length: int = 300,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+        yield_intermediates: bool = False,  # New parameter for yielding intermediates
     ):
-        """Generates images and yields intermediate results during the denoising process."""
+
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
-        # 1. Check inputs
+        # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             prompt_2,
@@ -79,6 +95,7 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
             width,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
 
@@ -87,12 +104,23 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
         self._interrupt = False
 
         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
         device = self._execution_device
 
-        # 3. Encode prompt
-        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
-        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            pooled_prompt_embeds,
+            text_ids,
+        ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
             prompt_embeds=prompt_embeds,
@@ -102,6 +130,21 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        (
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            negative_text_ids,
+        ) = self.encode_prompt(
+            prompt=negative_prompt,
+            prompt_2=negative_prompt_2,
+            prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
         latents, latent_image_ids = self.prepare_latents(
@@ -114,6 +157,7 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
             generator,
             latents,
         )
+
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = latents.shape[1]
@@ -126,43 +170,97 @@ class FLUXPipelineWithIntermediateOutputs(FluxPipeline):
             sigmas,
             mu=mu,
         )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
 
-        # Handle guidance
-        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
-
        # 6. Denoising loop
-        for i, t in enumerate(timesteps):
-            if self.interrupt:
-                continue
-
-            timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-            noise_pred = self.transformer(
-                hidden_states=latents,
-                timestep=timestep / 1000,
-                guidance=guidance,
-                pooled_projections=pooled_prompt_embeds,
-                encoder_hidden_states=prompt_embeds,
-                txt_ids=text_ids,
-                img_ids=latent_image_ids,
-                joint_attention_kwargs=self.joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-
-            # Yield intermediate result
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-            torch.cuda.empty_cache()
-
-        # Final image
-        return self._decode_latents_to_image(latents, height, width, output_type)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                # handle guidance
+                if self.transformer.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=device)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                noise_pred_text = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_uncond = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    guidance=guidance,
+                    pooled_projections=negative_pooled_prompt_embeds,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    txt_ids=negative_text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                # Yield intermediate images if requested
+                if yield_intermediates:
+                    yield self._decode_latents_to_image(latents, height, width, output_type)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        # Final image decoding
+        if output_type == "latent":
+            image = latents
+        else:
+            image = self._decode_latents_to_image(latents, height, width, output_type)
+
+        # Offload all models
        self.maybe_free_model_hooks()
-        torch.cuda.empty_cache()
 
-    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
+        if not return_dict:
+            return (image,)
+
+        return FluxPipelineOutput(images=image)
+
+    def _decode_latents_to_image(self, latents, height, width, output_type):
        """Decodes the given latents into an image."""
-        vae = vae or self.vae
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-        image = vae.decode(latents, return_dict=False)[0]
+        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(latents, return_dict=False)[0]
        return self.image_processor.postprocess(image, output_type=output_type)[0]
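
For reference, here is a minimal usage sketch of the pipeline after this change. The checkpoint name, dtype, prompts, and step count below are illustrative assumptions and are not part of the commit. Because the new __call__ contains yield statements, calling the pipeline always returns a generator: with yield_intermediates=True each denoising step yields a decoded image, and the last yielded image is the final result.

# Minimal usage sketch (assumes custom_pipeline.py is importable and a FLUX
# checkpoint such as "black-forest-labs/FLUX.1-schnell" is available).
import torch
from custom_pipeline import FluxWithCFGPipeline

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Iterate the generator to stream intermediate previews; keep the last frame.
final_image = None
for intermediate in pipe(
    prompt="a cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
    num_inference_steps=4,
    guidance_scale=3.5,
    yield_intermediates=True,
):
    final_image = intermediate

final_image.save("flux_cfg_output.png")

Note that classifier-free guidance runs the transformer twice per step (conditional and unconditional passes), so each denoising step costs roughly twice as much as the previous single-pass loop.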