jimmycarter
commited on
Commit
•
f5b866b
1
Parent(s):
de30f79
Upload 26 files
Browse files- .gitattributes +17 -0
- README.md +263 -3
- assets/comparisons/bear.jpg +0 -0
- assets/comparisons/lady.jpg +3 -0
- assets/comparisons/lime.jpg +3 -0
- assets/comparisons/moon.jpg +0 -0
- assets/comparisons/scars.jpg +0 -0
- assets/comparisons/selfie.jpg +0 -0
- assets/comparisons/teal_woman.jpg +3 -0
- assets/comparisons/witch.jpg +3 -0
- assets/comparisons_full/comparison_0.jpg +3 -0
- assets/comparisons_full/comparison_1.jpg +3 -0
- assets/comparisons_full/comparison_10.jpg +3 -0
- assets/comparisons_full/comparison_11.jpg +3 -0
- assets/comparisons_full/comparison_12.jpg +3 -0
- assets/comparisons_full/comparison_2.jpg +3 -0
- assets/comparisons_full/comparison_3.jpg +3 -0
- assets/comparisons_full/comparison_4.jpg +3 -0
- assets/comparisons_full/comparison_5.jpg +3 -0
- assets/comparisons_full/comparison_6.jpg +3 -0
- assets/comparisons_full/comparison_7.jpg +3 -0
- assets/comparisons_full/comparison_8.jpg +3 -0
- assets/comparisons_full/comparison_9.jpg +3 -0
- assets/comparisons_full/prompts.py +54 -0
- assets/science.png +0 -0
- assets/splash.jpg +0 -0
- pipeline.py +1813 -0
.gitattributes
CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/comparisons_full/comparison_0.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/comparisons_full/comparison_1.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/comparisons_full/comparison_10.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/comparisons_full/comparison_11.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/comparisons_full/comparison_12.jpg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/comparisons_full/comparison_2.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/comparisons_full/comparison_3.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/comparisons_full/comparison_4.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/comparisons_full/comparison_5.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/comparisons_full/comparison_6.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/comparisons_full/comparison_7.jpg filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/comparisons_full/comparison_8.jpg filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/comparisons_full/comparison_9.jpg filter=lfs diff=lfs merge=lfs -text
|
49 |
+
assets/comparisons/lady.jpg filter=lfs diff=lfs merge=lfs -text
|
50 |
+
assets/comparisons/lime.jpg filter=lfs diff=lfs merge=lfs -text
|
51 |
+
assets/comparisons/teal_woman.jpg filter=lfs diff=lfs merge=lfs -text
|
52 |
+
assets/comparisons/witch.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,3 +1,263 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LibreFLUX: A free, de-distilled FLUX model
|
2 |
+
|
3 |
+
LibreFLUX is an Apache 2.0 version of [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) that provides a full T5 context length, uses attention masking, has classifier free guidance restored, and has had most of the FLUX aesthetic finetuning/DPO fully removed. That means it's a lot uglier than base flux, but it has the potential to be more easily finetuned to any new distribution. It keeps in mind the core tenets of open source software, that it should be difficult to use, slower and clunkier than a proprietary solution, and have an aesthetic trapped somewhere inside the early 2000s.
|
4 |
+
|
5 |
+
![De-distillation t-shirt](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/splash.jpg)
|
6 |
+
|
7 |
+
> The image features a man standing confidently, wearing a simple t-shirt with a humorous and quirky message printed across the front. The t-shirt reads: "I de-distilled FLUX into a slow, ugly model and all I got was this stupid t-shirt." The man’s expression suggests a mix of pride and irony, as if he's aware of the complexity behind the statement, yet amused by the underwhelming reward. The background is neutral, keeping the focus on the man and his t-shirt, which pokes fun at the frustrating and often anticlimactic nature of technical processes or complex problem-solving, distilled into a comically understated punchline.
|
8 |
+
|
9 |
+
## Table of Contents
|
10 |
+
|
11 |
+
- [LibreFLUX: A free, de-distilled FLUX model](#libreflux-a-free-de-distilled-flux-model)
|
12 |
+
- [Usage](#usage)
|
13 |
+
- [Non-technical Report on Schnell De-distillation](#non-technical-report-on-schnell-de-distillation)
|
14 |
+
- [Why](#why)
|
15 |
+
- [Restoring the Original Training Objective](#restoring-the-original-training-objective)
|
16 |
+
- [FLUX and Attention Masking](#flux-and-attention-masking)
|
17 |
+
- [Make De-distillation Go Fast and Fit in Small GPUs](#make-de-distillation-go-fast-and-fit-in-small-gpus)
|
18 |
+
- [Selecting Better Layers to Train with LoKr](#selecting-better-layers-to-train-with-lokr)
|
19 |
+
- [Beta Timestep Scheduling and Timestep Stratification](#beta-timestep-scheduling-and-timestep-stratification)
|
20 |
+
- [Datasets](#datasets)
|
21 |
+
- [Training](#training)
|
22 |
+
- [Post-hoc "EMA"](#post-hoc-ema)
|
23 |
+
- [Results](#results)
|
24 |
+
- [Closing Thoughts](#closing-thoughts)
|
25 |
+
- [Contacting Me and Grants](#contacting-me-and-grants)
|
26 |
+
- [Citation](#citation)
|
27 |
+
|
28 |
+
# Usage
|
29 |
+
|
30 |
+
To use the model, just call the custom pipeline using [diffusers](https://github.com/huggingface/diffusers).
|
31 |
+
|
32 |
+
```py
|
33 |
+
from diffusers import DiffusionPipeline
|
34 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
35 |
+
"jimmycarter/LibreFLUX",
|
36 |
+
custom_pipeline="jimmycarter/LibreFLUX",
|
37 |
+
use_safetensors=True,
|
38 |
+
)
|
39 |
+
|
40 |
+
# High VRAM
|
41 |
+
prompt = "Photograph of a chalk board on which is written: 'I thought what I'd do was, I'd pretend I was one of those deaf-mutes.'"
|
42 |
+
negative_prompt = "blurry"
|
43 |
+
images = pipeline(
|
44 |
+
prompt=prompt,
|
45 |
+
negative_prompt=negative_prompt,
|
46 |
+
)
|
47 |
+
images[0].save('chalkboard.png')
|
48 |
+
|
49 |
+
# If you have <=24 GB VRAM, try:
|
50 |
+
# ! pip install optimum-quanto
|
51 |
+
# Then
|
52 |
+
from optimum.quanto import freeze, quantize, qint8
|
53 |
+
quantize(
|
54 |
+
pipe.transformer,
|
55 |
+
weights=qint8,
|
56 |
+
exclude=[
|
57 |
+
"*.norm", "*.norm1", "*.norm2", "*.norm2_context",
|
58 |
+
"proj_out", "x_embedder", "norm_out", "context_embedder",
|
59 |
+
],
|
60 |
+
)
|
61 |
+
freeze(pipe.transformer)
|
62 |
+
pipe.enable_model_cpu_offload()
|
63 |
+
images = pipeline(
|
64 |
+
prompt=prompt,
|
65 |
+
negative_prompt=negative_prompt,
|
66 |
+
device=None,
|
67 |
+
)
|
68 |
+
images[0].save('chalkboard.png')
|
69 |
+
```
|
70 |
+
|
71 |
+
# Non-technical Report on Schnell De-distillation
|
72 |
+
|
73 |
+
Welcome to my non-technical report on de-distilling FLUX.1-schnell in the most un-scientific way possible with extremely limited resources. I'm not going to claim I made a good model, but I did make a model. It was trained on about 1,500 H100 hour equivalents.
|
74 |
+
|
75 |
+
![Science.](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/science.png)
|
76 |
+
|
77 |
+
**Everyone is ~~an artist~~ a machine learning researcher.**
|
78 |
+
|
79 |
+
## Why
|
80 |
+
|
81 |
+
FLUX is a good text-to-image model, but the only versions of it that are out are distilled. FLUX.1-dev is distilled so that you don't need to use CFG (classifier free guidance), so instead of making one sample for conditional (your prompt) and unconditional (negative prompt), you only have to make the sample for conditional. This means that FLUX.1-dev is twice as fast as the model without distillation.
|
82 |
+
|
83 |
+
FLUX.1-schnell (German for "fast") is further distilled so that you only need 4 steps of conditional generation to get an image. Importantly, FLUX.1-schnell has an Apache-2.0 license, so you can use it freely without having to obtain a commercial license from Black Forest Labs. Out of the box, schnell is pretty bad when you use CFG unless you skip the first couple of steps.
|
84 |
+
|
85 |
+
The FLUX distilled models are created for their base, non-distilled models by [training on output from the teacher model (non-distilled) to student model (distilled) along with some tricks like an adversarial network](https://arxiv.org/abs/2403.12015).
|
86 |
+
|
87 |
+
For de-distilled models, image generation takes a little less than twice as long because you need to compute a sample for both conditional and unconditional images at each step. The benefit is you can use them commercially for free, training is a little easier, and they may be more creative.
|
88 |
+
|
89 |
+
## Restoring the original training objective
|
90 |
+
|
91 |
+
This part is actually really easy. You just train it on the normal flow-matching objective with MSE loss and the model starts learning how to do it again. That being said, I don't think either LibreFLUX or [OpenFLUX.1](https://huggingface.co/ostris/OpenFLUX.1) managed to fully de-distill the model. The evidence I see for that is that both models will either get strange shadows that overwhelm the image or blurriness when using CFG scale values greater than 4.0. Neither of us trained very long in comparison to the training for the original model (assumed to be around 0.5-2.0m H100 hours), so it's not particularly surprising.
|
92 |
+
|
93 |
+
## FLUX and attention masking
|
94 |
+
|
95 |
+
FLUX models use a text model called T5-XXL to get most of its conditioning for the text-to-image task. Importantly, they pad the text out to either 256 (schnell) or 512 (dev) tokens. 512 tokens is the maximum trained length for the model. By padding, I mean they repeat the last token until the sequence is this length.
|
96 |
+
|
97 |
+
This results in the model using these padding tokens to [store information](https://arxiv.org/abs/2309.16588). When you [visualize the attention maps of the tokens in the padding segment of the text encoder](https://github.com/kaibioinfo/FluxAttentionMap/blob/main/attentionmap.ipynb), you can see that about 10-40 tokens shortly after the last token of the text and about 10-40 tokens at the end of the padding contain information which the model uses to make images. Because these are normally used to store information, it means that any prompt long enough to not have some of these padding tokens will end up with degraded performance.
|
98 |
+
|
99 |
+
It's easy to prevent this by masking out these padding token during attention. BFL and their engineers know this, but they probably decided against it because it works as is and most fast implementations of attention only work with causal (LLM) types of padding and so would let them train faster.
|
100 |
+
|
101 |
+
I already [implemented attention masking](https://github.com/bghira/SimpleTuner/blob/main/helpers/models/flux/transformer.py#L404-L406) and I would like to be able to use all 512 tokens without degradation, so I did my finetune with it on. Small scale finetunes with it on tend to damage the model, but since I need to train so much out of distillation schnell to make it work anyway I figured it probably didn't matter to add it.
|
102 |
+
|
103 |
+
Note that FLUX.1-schnell was only trained on 256 tokens, so my finetune allows users to use the whole 512 token sequence length.
|
104 |
+
|
105 |
+
## Make de-distillation go fast and fit in small GPUs
|
106 |
+
|
107 |
+
I avoided doing any full-rank (normal, all parameters) finetuning at all, since FLUX is big. I trained initially with the model in int8 precision using [quanto](https://github.com/huggingface/optimum-quanto). I started with a 600 million parameter [LoKr](https://arxiv.org/abs/2309.14859), since LoKr tends to approximate full-rank finetuning better than LoRA. The loss was really slow to go down when I began, so after poking around the code to initialize the matrix to apply to the LoKr I settled on this function, which injects noise at a fraction of the magnitudes of the layers they apply to.
|
108 |
+
|
109 |
+
```py
|
110 |
+
def approximate_normal_tensor(inp, target, scale=1.0):
|
111 |
+
tensor = torch.randn_like(target)
|
112 |
+
desired_norm = inp.norm()
|
113 |
+
desired_mean = inp.mean()
|
114 |
+
desired_std = inp.std()
|
115 |
+
|
116 |
+
current_norm = tensor.norm()
|
117 |
+
tensor = tensor * (desired_norm / current_norm)
|
118 |
+
current_std = tensor.std()
|
119 |
+
tensor = tensor * (desired_std / current_std)
|
120 |
+
tensor = tensor - tensor.mean() + desired_mean
|
121 |
+
tensor.mul_(scale)
|
122 |
+
|
123 |
+
target.copy_(tensor)
|
124 |
+
|
125 |
+
|
126 |
+
def init_lokr_network_with_perturbed_normal(lycoris, scale=1e-3):
|
127 |
+
with torch.no_grad():
|
128 |
+
for lora in lycoris.loras:
|
129 |
+
lora.lokr_w1.fill_(1.0)
|
130 |
+
approximate_normal_tensor(lora.org_weight, lora.lokr_w2, scale=scale)
|
131 |
+
```
|
132 |
+
|
133 |
+
This isn't normal PEFT (parameter efficient fine-tuning) anymore, because this will perturb all the weights of the model slightly in the beginning. It doesn't seem to cause any performance degradation in the model after testing and it made the loss fall for my LoKr twice as fast, so I used it with `scale=1e-3`. The LoKr weights I trained in bfloat16, with the `adamw_bf16` optimizer that I ~~plagiarized~~ wrote with the magic of open source software.
|
134 |
+
|
135 |
+
## Selecting better layers to train with LoKr
|
136 |
+
|
137 |
+
FLUX is a pretty standard transformer model aside from some peculiarities. One of these peculiarities is in their "norm" layers, which contain non-linearities so they don't act like norms except for a single normalization that is applied in the layer without any weights (LayerNorm with `elementwise_affine=False`). When you fine-tune and look at what changes these layers are one of the big ones that seems to change.
|
138 |
+
|
139 |
+
The other thing about transformers is that [all the heavy lifting is most often done at the start and end layers of the network](https://arxiv.org/abs/2403.17887), so you may as well fine-tune those more than other layers. When I looked at the cosine similarity of the hidden states between each block in diffusion transformers, it more or less reflected what was observed with LLMs. So I made a pull-request to the LyCORIS repository (that maintains a LoKr implementation) that lets you more easily pick individual layers and set different factors on them, then focused my LoKr on these layers.
|
140 |
+
|
141 |
+
## Beta timestep scheduling and timestep stratification
|
142 |
+
|
143 |
+
One problem with diffusion models is that they are [multi-task](https://arxiv.org/abs/2211.01324) (different timesteps are considered different tasks) and the tasks all tend to be associated with differently shaped and sized gradients and different magnitudes of loss. This is very much not a big deal when you have a huge batch size, so the timesteps of the model all get more or less sampled evenly and the gradients are smoothed out and have less variance. I also knew that the schnell model had more problems with image distortions caused by sampling at the high-noise timesteps, so I did two things:
|
144 |
+
|
145 |
+
1. Implemented a Beta schedule that approximates the original sigmoid sampling, to let me shift the timesteps sampled to the high noise steps similar but less extreme than some of the alternative sampling methods in the SD3 research paper.
|
146 |
+
2. Implement multi-rank stratified sampling so that during each step the model trained timesteps were selected per batch based on regions, which normalizes the gradients significantly like using a higher batch size would.
|
147 |
+
|
148 |
+
```py
|
149 |
+
alpha = 2.0
|
150 |
+
beta = 2.0
|
151 |
+
num_processes = self.accelerator.num_processes
|
152 |
+
process_index = self.accelerator.process_index
|
153 |
+
total_bsz = num_processes * bsz
|
154 |
+
start_idx = process_index * bsz
|
155 |
+
end_idx = (process_index + 1) * bsz
|
156 |
+
indices = torch.arange(start_idx, end_idx, dtype=torch.float64)
|
157 |
+
u = torch.rand(bsz)
|
158 |
+
p = (indices + u) / total_bsz
|
159 |
+
sigmas = torch.from_numpy(
|
160 |
+
sp_beta.ppf(p.numpy(), a=alpha, b=beta)
|
161 |
+
).to(device=self.accelerator.device)
|
162 |
+
```
|
163 |
+
|
164 |
+
## Datasets
|
165 |
+
|
166 |
+
No one talks about what datasets they train anymore, but I used open ones from the web captioned with VLMs and 2-3 captions per image. There was at least one short and one long caption for every image. The datasets were diverse and most of them did not have aesthetic selection, which helped direct the model away from the traditional hyper-optimized image generation of text-to-image models. Many people think that looks worse, but I like that it can make a diverse pile of images. The model was trained on about 0.5 million high resolution images in both random square crops and random aspect ratio crops.
|
167 |
+
|
168 |
+
## Training
|
169 |
+
|
170 |
+
I started training for over a month on a 5x 3090s and about 500,000 images. I used a 600m LoKr for this. The model looked okay after. Then, I [unexpectedly gained access to 7x H100s for compute resources](https://rundiffusion.com), so I merged my PEFT model in and began training on a new LoKr with 3.2b parameters.
|
171 |
+
|
172 |
+
## Post-hoc "EMA"
|
173 |
+
|
174 |
+
I've been too lazy to implement real [post-hoc EMA like from EDM2](https://github.com/lucidrains/ema-pytorch/blob/main/ema_pytorch/post_hoc_ema.py), but to approximate it I saved all the checkpoints from the H100 runs and then LERPed them iteratively with different alpha values. I evaluated those checkpoints at different CFG scales to see if any of them were superior to the last checkpoint.
|
175 |
+
|
176 |
+
```py
|
177 |
+
first_checkpoint_file = checkpoint_files[0]
|
178 |
+
ema_state_dict = load_file(first_checkpoint_file)
|
179 |
+
for checkpoint_file in checkpoint_files[1:]:
|
180 |
+
new_state_dict = load_file(checkpoint_file)
|
181 |
+
for k in ema_state_dict.keys():
|
182 |
+
ema_state_dict[k] = torch.lerp(
|
183 |
+
ema_state_dict[k],
|
184 |
+
new_state_dict[k],
|
185 |
+
alpha,
|
186 |
+
)
|
187 |
+
|
188 |
+
output_file = os.path.join(output_folder, f"alpha_linear_{alpha}.safetensors")
|
189 |
+
save_file(ema_state_dict, output_file)
|
190 |
+
```
|
191 |
+
|
192 |
+
After looking at all models in alphas `[0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 0.995, 0.999]`, I ended up settling on alpha 0.9 using the power of my eyeballs. If I am being frank, many of the EMA models looked remarkably similar and had the same kind of "rolling around various minima" qualities that training does in general.
|
193 |
+
|
194 |
+
## Results
|
195 |
+
|
196 |
+
I will go over the results briefly, but I'll start with the images.
|
197 |
+
|
198 |
+
**Figure 1.** Some side-by-side images of LibreFLUX and [OpenFLUX.1](https://huggingface.co/ostris/OpenFLUX.1). They were made using diffusers, with 512-token maximum length text embeddings for LibreFLUX and 256-token maximum length for OpenFLUX.1. LibreFLUX had attention masking on while OpenFLUX did not. The models were sampled with 35 steps at various resolutions. The negative prompt for both was simply "blurry". All inference was done with the transformer quantized to int8 by quanto.
|
199 |
+
|
200 |
+
![Polar bear](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/bear.jpg)
|
201 |
+
|
202 |
+
> A cinematic style shot of a polar bear standing confidently in the center of a vibrant nightclub. The bear is holding a large sign that reads 'Open Source! Apache 2.0' in one arm and giving a thumbs up with the other arm. Around him, the club is alive with energy as colorful lasers and disco lights illuminate the scene. People are dancing all around him, wearing glowsticks and candy bracelets, adding to the fun and electric atmosphere. The polar bear's white fur contrasts against the dark, neon-lit background, and the entire scene has a surreal, festive vibe, blending technology activism with a lively party environment.
|
203 |
+
|
204 |
+
![Artistic picture of woman](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/lady.jpg)
|
205 |
+
|
206 |
+
> widescreen, vintage style from 1970s, Extreme realism in a complex, highly detailed composition featuring a woman with extremely long flowing rainbow-colored hair. The glowing background, with its vibrant colors, exaggerated details, intricate textures, and dynamic lighting, creates a whimsical, dreamy atmosphere in photorealistic quality. Threads of light that float and weave through the air, adding movement and intrigue. Patterns on the ground or in the background that glow subtly, adding a layer of complexity.Rainbows that appear faintly in the background, adding a touch of color and wonder.Butterfly wings that shimmer in the light, adding life and movement to the scene.Beams of light that radiate softly through the scene, adding focus and direction. The woman looks away from the camera, with a soft, wistful expression, her hair framing her face.
|
207 |
+
|
208 |
+
![Western movie poster](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/lime.jpg)
|
209 |
+
|
210 |
+
> a highly detailed and atmospheric, painted western movie poster with the title text "Once Upon a Lime in the West" in a dark red western-style font and the tagline text "There were three men ... and one very sour twist", with movie credits at the bottom, featuring small white text detailing actor and director names and production company logos, inspired by classic western movie posters from the 1960s, an oversized lime is the central element in the middle ground of a rugged, sun-scorched desert landscape typical of a western, the vast expanse of dry, cracked earth stretches toward the horizon, framed by towering red rock formations, the absurdity of the lime is juxtaposed with the intense gravitas of the stoic, iconic gunfighters, as if the lime were as formidable an adversary as any seasoned gunslinger, in the foreground, the silhouettes of two iconic gunfighters stand poised, facing the lime and away from the viewer, the lime looms in the distance like a final showdown in the classic western tradition, in the foreground, the gunfighters stand with long duster coats flowing in the wind, and wide-brimmed hats tilted to cast shadows over their faces, their stances are tense, as if ready for the inevitable draw, and the weapons they carry glint, the background consists of the distant town, where the sun is casting a golden glow, old wooden buildings line the sides, with horses tied to posts and a weathered saloon sign swinging gently in the wind, in this poster, the lime plays the role of the silent villain, an almost mythical object that the gunfighters are preparing to confront, the tension of the scene is palpable, the gunfighters in the foreground have faces marked by dust and sweat, their eyes narrowed against the bright sunlight, their expressions are serious and resolute, as if they have come a long way for this final duel, the absurdity of the lime is in stark contrast with their stoic demeanor, a wide, panoramic shot captures the entire scene, with the gunfighters in the foreground, the lime in the mid-ground, and the town on the horizon, the framing emphasizes the scale of the desert and the dramatic standoff taking place, while subtly highlighting the oversized lime, the camera is positioned low, angled upward from the dusty ground toward the gunfighters, with the distant lime looming ahead, this angle lends the figures an imposing presence, while still giving the lime an absurd grandeur in the distance, the perspective draws the viewerâs eye across the desert, from the silhouettes of the gunfighters to the bizarre focal point of the lime, amplifying the tension, the lighting is harsh and unforgiving, typical of a desert setting, with the evening sun casting deep shadows across the ground, dust clouds drift subtly across the ground, creating a hazy effect, while the sky above is a vast expanse of pale blue, fading into golden hues near the horizon where the sun begins to set, the poster is shot as if using classic anamorphic lenses to capture the wide, epic scale of the desert, the color palette is warm and saturated, evoking the look of a classic spaghetti western, the lime looms unnaturally in the distance, as if conjured from the land itself, casting an absurdly grand shadow across the rugged landscape, the texture and detail evoke hand-painted, weathered posters from the golden age of westerns, with slightly frayed edges and faint creases mimicking the wear of vintage classics
|
211 |
+
|
212 |
+
![Witch action figure](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/witch.jpg)
|
213 |
+
|
214 |
+
> A boxed action figure of a beautiful elf girl witch wearing a skimpy black leotard, black thigh highs, black armlets, and a short black cloak. Her hair is pink and shoulder-length. Her eyes are green. She is a slim and attractive elf with small breasts. The accessories include an apple, magic wand, potion bottle, black cat, jack o lantern, and a book. The box is orange and black with a logo near the bottom of it that says "BAD WITCH". The box is on a shelf on the toy aisle.
|
215 |
+
|
216 |
+
![Photograph of woman in teal room with dog](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/teal_woman.jpg)
|
217 |
+
|
218 |
+
> A cute blonde woman in bikini and her doge are sitting on a couch cuddling and the expressive, stylish living room scene with a playful twist. The room is painted in a soothing turquoise color scheme, stylish living room scene bathed in a cool, textured turquoise blanket and adorned with several matching turquoise throw pillows. The room's color scheme is predominantly turquoise, relaxed demeanor. The couch is covered in a soft, reflecting light and adding to the vibrant blue hue., dark room with a sleek, spherical gold decorations, This photograph captures a scene that is whimsically styled in a vibrant, reflective cyan sunglasses. The dog's expression is cheerful, metallic fabric sofa. The dog, soothing atmosphere.
|
219 |
+
|
220 |
+
![Selfie of a man and woman](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/selfie.jpg)
|
221 |
+
|
222 |
+
> Selfie of a woman in front of the eiffel tower, a man is standing next to her and giving a thumbs up
|
223 |
+
|
224 |
+
![Image of just text](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/scars.jpg)
|
225 |
+
|
226 |
+
> An image contains three motivational phrases, all in capitalized stylized text on a colorful background: 1. At the top: "PAIN HEALS" 2. In the middle, bold and slightly larger: "CHICKS DIG SCARS" 3. At the bottom: "GLORY LASTS FOREVER"
|
227 |
+
|
228 |
+
![Digital art with lots of details specified of McDonald's on the moon](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons/moon.jpg)
|
229 |
+
|
230 |
+
> An illustration featuring a McDonald's on the moon. An anthropomorphic cat in a pink top and blue jeans is ordering McDonald's, while a zebra cashier stands behind the counter. The moon's surface is visible outside the windows, with craters and a distant view of Earth. The interior of the McDonald's is similar to those on Earth but adapted to the lunar environment, with vibrant colors and futuristic design elements. The overall scene is whimsical and imaginative, blending everyday life with a fantastical setting.
|
231 |
+
|
232 |
+
LibreFLUX and OpenFLUX have their strengths and weaknesses. OpenFLUX was de-distilled using the outputs of FLUX.1-schnell, which might explain why it's worse at text but also has the FLUX hyperaesthetics. Text-to-image models [don't have any good metrics](https://arxiv.org/abs/2306.04675) so past a point of "soupiness" and single digit FID you just need to look at the model and see if it fits what you think nice pictures are.
|
233 |
+
|
234 |
+
Both models appear to be terrible at making drawings. Because people are probably curious to see the non-cherry picks, [I've included CFG sweep comparisons of both LibreFLUX and OpenFLUX.1 here](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/comparisons_full/). I'm not going to say this is the best model ever, but it might be a springboard for people wanting to finetune better models from.
|
235 |
+
|
236 |
+
## Closing thoughts
|
237 |
+
|
238 |
+
If I had to do it again, I'd probably raise the learning rate more on the H100 run. There was a [bug in SimpleTuner](https://github.com/bghira/SimpleTuner/issues/1064) that caused me to not use the [initialization trick](#make-de-distillation-go-fast-and-fit-in-small-gpus) when on the H100s, then [timestep stratification](#beta-timestep-scheduling-and-timestep-stratification) ended up quieting down the gradient magnitudes even more and caused the model to learn very slowly at `1e-5`. I realized this when looking at the results of EMA on the final FLUX.1-dev. The H100s really came out of nowhere as I just got an IP address to shell into late one night around 10PM and ended up staying up all night to get everything running, so in the future I'm sure I would be more prepared.
|
239 |
+
|
240 |
+
For de-distillation of schnell I think you probably need a lot more than 1500 H100-equivalent hours. I am very tired of training FLUX and am looking forward to a better model with less parameters. The model learns new concepts slowly when given piles of well labeled data. Given the history of LLMs, we now have models like LLaMA 3.1 8B that trade blows with GPT3.5 175B and I am hopeful that the future holds [smaller, faster models that look better](https://openreview.net/pdf?id=jQP5o1VAVc).
|
241 |
+
|
242 |
+
As far as what I think of the FLUX "open source", many models being trained and released today are attempts at raising VC cash and I have noticed a mountain of them being promoted on Twitter. Since [a16z poached the entire SD3 dev team from Stability.ai](https://siliconcanals.com/black-forest-labs-secures-28m/) the field feels more toxic than ever, but I am hopeful for individuals and research labs to selflessly lead the path forward for open weights. I made zero dollars on this and have made zero dollars on ML to date, but I try to make contributions where I can.
|
243 |
+
|
244 |
+
![The state of open source](https://huggingface.co/jimmycarter/LibreFLUX/blob/main/assets/opensource.png)
|
245 |
+
|
246 |
+
I would like to thank [RunDiffusion](https://rundiffusion.com) for the H100 access.
|
247 |
+
|
248 |
+
## Contacting me and grants
|
249 |
+
|
250 |
+
You can contact me by opening an issue on the discuss page of this model. If you want to speak privately about grants because you want me to continue training this or give me a means to conduct reproducible research, leave an email address too.
|
251 |
+
|
252 |
+
## Citation
|
253 |
+
|
254 |
+
```
|
255 |
+
@misc{libreflux,
|
256 |
+
author = {James Carter},
|
257 |
+
title = {LibreFLUX: A free, de-distilled FLUX model},
|
258 |
+
year = {2024},
|
259 |
+
publisher = {Huggingface},
|
260 |
+
journal = {Huggingface repository},
|
261 |
+
howpublished = {\url{https://huggingface.co/datasets/jimmycarter/libreflux}},
|
262 |
+
}
|
263 |
+
```
|
assets/comparisons/bear.jpg
ADDED
assets/comparisons/lady.jpg
ADDED
Git LFS Details
|
assets/comparisons/lime.jpg
ADDED
Git LFS Details
|
assets/comparisons/moon.jpg
ADDED
assets/comparisons/scars.jpg
ADDED
assets/comparisons/selfie.jpg
ADDED
assets/comparisons/teal_woman.jpg
ADDED
Git LFS Details
|
assets/comparisons/witch.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_0.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_1.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_10.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_11.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_12.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_2.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_3.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_4.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_5.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_6.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_7.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_8.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/comparison_9.jpg
ADDED
Git LFS Details
|
assets/comparisons_full/prompts.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompts = [
|
2 |
+
( # 0
|
3 |
+
"A wide format poster featuring George Washington atop a glorious bald eagle with its wings spread flying through the sky, background is the American flag and fireworks. Huge shiny Red white and blue, bold, gradient letters at the bottom spelling out \"WTF IS A KILOMETER\" in flaming text. 4k, masterpiece.",
|
4 |
+
(1536, 1024),
|
5 |
+
),
|
6 |
+
( # 1
|
7 |
+
"widescreen, vintage style from 1970s, Extreme realism in a complex, highly detailed composition featuring a woman with extremely long flowing rainbow-colored hair. The glowing background, with its vibrant colors, exaggerated details, intricate textures, and dynamic lighting, creates a whimsical, dreamy atmosphere in photorealistic quality. Threads of light that float and weave through the air, adding movement and intrigue. Patterns on the ground or in the background that glow subtly, adding a layer of complexity.Rainbows that appear faintly in the background, adding a touch of color and wonder.Butterfly wings that shimmer in the light, adding life and movement to the scene.Beams of light that radiate softly through the scene, adding focus and direction. The woman looks away from the camera, with a soft, wistful expression, her hair framing her face. ",
|
8 |
+
(1536, 1024),
|
9 |
+
),
|
10 |
+
( # 2
|
11 |
+
'A cinematic style shot of a polar bear standing confidently in the center of a vibrant nightclub. The bear is holding a large sign that reads \'Open Source! Apache 2.0\' in one arm and giving a thumbs up with the other arm. Around him, the club is alive with energy as colorful lasers and disco lights illuminate the scene. People are dancing all around him, wearing glowsticks and candy bracelets, adding to the fun and electric atmosphere. The polar bear\'s white fur contrasts against the dark, neon-lit background, and the entire scene has a surreal, festive vibe, blending technology activism with a lively party environment.',
|
12 |
+
(1536, 1024),
|
13 |
+
),
|
14 |
+
( # 3
|
15 |
+
'A boxed action figure of a beautiful elf girl witch wearing a skimpy black leotard, black thigh highs, black armlets, and a short black cloak. Her hair is pink and shoulder-length. Her eyes are green. She is a slim and attractive elf with small breasts. The accessories include an apple, magic wand, potion bottle, black cat, jack o lantern, and a book. The box is orange and black with a logo near the bottom of it that says "BAD WITCH". The box is on a shelf on the toy aisle.',
|
16 |
+
(1024, 1536),
|
17 |
+
),
|
18 |
+
( # 4
|
19 |
+
"A cute blonde woman in bikini and her doge are sitting on a couch cuddling and the expressive, stylish living room scene with a playful twist. The room is painted in a soothing turquoise color scheme, stylish living room scene bathed in a cool, textured turquoise blanket and adorned with several matching turquoise throw pillows. The room's color scheme is predominantly turquoise, relaxed demeanor. The couch is covered in a soft, reflecting light and adding to the vibrant blue hue., dark room with a sleek, spherical gold decorations, This photograph captures a scene that is whimsically styled in a vibrant, reflective cyan sunglasses. The dog's expression is cheerful, metallic fabric sofa. The dog, soothing atmosphere.",
|
20 |
+
(1536, 1024),
|
21 |
+
),
|
22 |
+
( # 5
|
23 |
+
"Bioluminescent, A hyperrealistic depiction of a surreal scene: a piano keyboard morphs into a spiral staircase, ascending into a swirling vortex of golden, autumnal hues. A figure with a porcelain mask, reminiscent of commedia dell'arte, emerges from beneath the keys, their hand extended towards a lone female figure in a flowing gown at the apex of the staircase. Emphasize the juxtaposition of the organic and geometric, the tangible and ethereal, with a chiaroscuro lighting style. Capture the melancholic beauty and enigmatic narrative inherent in the scene.",
|
24 |
+
(1024, 1536),
|
25 |
+
),
|
26 |
+
( # 6
|
27 |
+
"highly detailed cinematic movie poster with the text \"PACIFIST RIM\" in a bold, vibrant sci-fi-style font at the top and a tagline reading \"Saving the world, one bouquet at a time\" below it, with the movie credits at the bottom, in the foreground, a gigantic tailless humanoid bipedal mecha-robot and an equally massive kaiju with blue-green iridescent scales and bioluminescent accents stand face-to-face, the enormous mecha on the left is clad in battle-worn yet gleaming metallic armor plates, holding out a large bouquet of exotic, colorful flowers to the kaiju on the right, the kaiju looks surprised by the gesture, its grotesque, otherworldly, surreal form equipped with a row of cyan glowing crystalline spikes along its back, the setting is an urban waterfront, framed by towering skyscrapers and a shimmering ocean with soft waves behind them, the background is bathed in the glow of moonlight and flickering neon billboards with messages like \"HARMONY\" and \"PEACE,\" tiny people on the ground below snap photos with their phones, while some onlookers stare in disbelief, behind the two titanic figures, the calm ocean glistens under the moonlight as distant ships drift by, the robot's posture is calm and serene, its large claw-like hands extended as it presents the bouquet, both figures express an aura of peace and harmony despite their intimidating size, the monstrous kaiju, though menacing, is curious and seemingly receptive to the offering, the overall mood is tranquil, a wide-angle shot captures both the robot and kaiju in their full, towering forms, emphasizing their colossal scale against the peaceful cityscape backdrop, the waterfront and moonlight add depth, while a low-angle shot looking up at the robot and kaiju further enhances their imposing size, the lighting is vibrant and saturated, with soft moonlight reflecting off the robotâs metallic surface, neon city lights providing colorful accents, lens flare, dappled lighting",
|
28 |
+
(1024, 1536),
|
29 |
+
),
|
30 |
+
( # 7
|
31 |
+
'a highly detailed and atmospheric, painted western movie poster with the title text "Once Upon a Lime in the West" in a dark red western-style font and the tagline text "There were three men ... and one very sour twist", with movie credits at the bottom, featuring small white text detailing actor and director names and production company logos, inspired by classic western movie posters from the 1960s, an oversized lime is the central element in the middle ground of a rugged, sun-scorched desert landscape typical of a western, the vast expanse of dry, cracked earth stretches toward the horizon, framed by towering red rock formations, the absurdity of the lime is juxtaposed with the intense gravitas of the stoic, iconic gunfighters, as if the lime were as formidable an adversary as any seasoned gunslinger, in the foreground, the silhouettes of two iconic gunfighters stand poised, facing the lime and away from the viewer, the lime looms in the distance like a final showdown in the classic western tradition, in the foreground, the gunfighters stand with long duster coats flowing in the wind, and wide-brimmed hats tilted to cast shadows over their faces, their stances are tense, as if ready for the inevitable draw, and the weapons they carry glint, the background consists of the distant town, where the sun is casting a golden glow, old wooden buildings line the sides, with horses tied to posts and a weathered saloon sign swinging gently in the wind, in this poster, the lime plays the role of the silent villain, an almost mythical object that the gunfighters are preparing to confront, the tension of the scene is palpable, the gunfighters in the foreground have faces marked by dust and sweat, their eyes narrowed against the bright sunlight, their expressions are serious and resolute, as if they have come a long way for this final duel, the absurdity of the lime is in stark contrast with their stoic demeanor, a wide, panoramic shot captures the entire scene, with the gunfighters in the foreground, the lime in the mid-ground, and the town on the horizon, the framing emphasizes the scale of the desert and the dramatic standoff taking place, while subtly highlighting the oversized lime, the camera is positioned low, angled upward from the dusty ground toward the gunfighters, with the distant lime looming ahead, this angle lends the figures an imposing presence, while still giving the lime an absurd grandeur in the distance, the perspective draws the viewerâs eye across the desert, from the silhouettes of the gunfighters to the bizarre focal point of the lime, amplifying the tension, the lighting is harsh and unforgiving, typical of a desert setting, with the evening sun casting deep shadows across the ground, dust clouds drift subtly across the ground, creating a hazy effect, while the sky above is a vast expanse of pale blue, fading into golden hues near the horizon where the sun begins to set, the poster is shot as if using classic anamorphic lenses to capture the wide, epic scale of the desert, the color palette is warm and saturated, evoking the look of a classic spaghetti western, the lime looms unnaturally in the distance, as if conjured from the land itself, casting an absurdly grand shadow across the rugged landscape, the texture and detail evoke hand-painted, weathered posters from the golden age of westerns, with slightly frayed edges and faint creases mimicking the wear of vintage classics',
|
32 |
+
(1024, 1536),
|
33 |
+
),
|
34 |
+
( # 8
|
35 |
+
'Anime illustration of a man standing next to a cat',
|
36 |
+
(1024, 1024),
|
37 |
+
),
|
38 |
+
( # 9
|
39 |
+
'Selfie of a woman in front of the eiffel tower, a man is standing next to her and giving a thumbs up',
|
40 |
+
(1024, 1024),
|
41 |
+
),
|
42 |
+
( # 10
|
43 |
+
"a life-sized, clear plastic action figure box with a real woman trapped inside. The box has vibrant, eye-catching colors, featuring bold logos and text reminiscent of classic action figure packaging. The woman stands stiffly in the middle, her pose rigid like a doll, her facial expression conveying a mixture of confusion and surprise. She wears a brightly colored outfit that matches the action figure aesthetic, with exaggerated accessories like a toy sword or futuristic helmet strapped to her side. The box\’s background features bold comic-book-like artwork, framing the woman with dynamic lines and cartoonish explosions, emphasizing the \"action\" theme. The plastic window on the front covers the woman\’s entire body, while the sides display branding and promotional text, like \“Superhero Edition\” or \“Ultimate Collector\’s Item!\” Around the box, toy-like details abound: barcodes, toy company logos, and descriptions of her \“powers\” or \“abilities\” written in comic-style font.",
|
44 |
+
(1024, 1536),
|
45 |
+
),
|
46 |
+
( # 11
|
47 |
+
'An image contains three motivational phrases, all in capitalized stylized text on a colorful background: 1. At the top: "PAIN HEALS" 2. In the middle, bold and slightly larger: "CHICKS DIG SCARS" 3. At the bottom: "GLORY LASTS FOREVER"',
|
48 |
+
(1024, 1024),
|
49 |
+
),
|
50 |
+
( # 12
|
51 |
+
'An illustration featuring a McDonald\'s on the moon. An anthropomorphic cat in a pink top and blue jeans is ordering McDonald\'s, while a zebra cashier stands behind the counter. The moon\'s surface is visible outside the windows, with craters and a distant view of Earth. The interior of the McDonald\'s is similar to those on Earth but adapted to the lunar environment, with vibrant colors and futuristic design elements. The overall scene is whimsical and imaginative, blending everyday life with a fantastical setting.',
|
52 |
+
(1024, 1024),
|
53 |
+
),
|
54 |
+
]
|
assets/science.png
ADDED
assets/splash.jpg
ADDED
pipeline.py
ADDED
@@ -0,0 +1,1813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
|
2 |
+
#
|
3 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
#
|
17 |
+
# Originally licensed under the Apache License, Version 2.0 (the "License");
|
18 |
+
# Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
|
19 |
+
|
20 |
+
from typing import Any, Dict, List, Optional, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
28 |
+
from diffusers.models.attention import FeedForward
|
29 |
+
from diffusers.models.attention_processor import (
|
30 |
+
Attention,
|
31 |
+
apply_rope,
|
32 |
+
)
|
33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
34 |
+
from diffusers.models.normalization import (
|
35 |
+
AdaLayerNormContinuous,
|
36 |
+
AdaLayerNormZero,
|
37 |
+
AdaLayerNormZeroSingle,
|
38 |
+
)
|
39 |
+
from diffusers.utils import (
|
40 |
+
USE_PEFT_BACKEND,
|
41 |
+
is_torch_version,
|
42 |
+
logging,
|
43 |
+
scale_lora_layers,
|
44 |
+
unscale_lora_layers,
|
45 |
+
)
|
46 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
47 |
+
from diffusers.models.embeddings import (
|
48 |
+
CombinedTimestepGuidanceTextProjEmbeddings,
|
49 |
+
CombinedTimestepTextProjEmbeddings,
|
50 |
+
)
|
51 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
52 |
+
|
53 |
+
from dataclasses import dataclass
|
54 |
+
from typing import List, Union
|
55 |
+
import PIL.Image
|
56 |
+
from diffusers.utils import BaseOutput
|
57 |
+
|
58 |
+
import inspect
|
59 |
+
from functools import lru_cache
|
60 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
61 |
+
|
62 |
+
import numpy as np
|
63 |
+
import torch
|
64 |
+
from transformers import (
|
65 |
+
CLIPTextModel,
|
66 |
+
CLIPTokenizer,
|
67 |
+
T5EncoderModel,
|
68 |
+
T5TokenizerFast,
|
69 |
+
)
|
70 |
+
|
71 |
+
from diffusers.image_processor import VaeImageProcessor
|
72 |
+
from diffusers.loaders import SD3LoraLoaderMixin
|
73 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
74 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
75 |
+
from diffusers.utils import (
|
76 |
+
USE_PEFT_BACKEND,
|
77 |
+
is_torch_xla_available,
|
78 |
+
logging,
|
79 |
+
replace_example_docstring,
|
80 |
+
scale_lora_layers,
|
81 |
+
unscale_lora_layers,
|
82 |
+
)
|
83 |
+
from diffusers.utils.torch_utils import randn_tensor
|
84 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
85 |
+
|
86 |
+
if is_torch_xla_available():
|
87 |
+
import torch_xla.core.xla_model as xm
|
88 |
+
|
89 |
+
XLA_AVAILABLE = True
|
90 |
+
else:
|
91 |
+
XLA_AVAILABLE = False
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class FluxPipelineOutput(BaseOutput):
|
96 |
+
"""
|
97 |
+
Output class for Stable Diffusion pipelines.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
101 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
102 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
103 |
+
"""
|
104 |
+
|
105 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
106 |
+
|
107 |
+
|
108 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
109 |
+
|
110 |
+
|
111 |
+
class FluxSingleAttnProcessor2_0:
|
112 |
+
r"""
|
113 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(self):
|
117 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
118 |
+
raise ImportError(
|
119 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
120 |
+
)
|
121 |
+
|
122 |
+
def __call__(
|
123 |
+
self,
|
124 |
+
attn: Attention,
|
125 |
+
hidden_states: torch.Tensor,
|
126 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
127 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
128 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
129 |
+
) -> torch.Tensor:
|
130 |
+
input_ndim = hidden_states.ndim
|
131 |
+
|
132 |
+
if input_ndim == 4:
|
133 |
+
batch_size, channel, height, width = hidden_states.shape
|
134 |
+
hidden_states = hidden_states.view(
|
135 |
+
batch_size, channel, height * width
|
136 |
+
).transpose(1, 2)
|
137 |
+
|
138 |
+
batch_size, _, _ = hidden_states.shape
|
139 |
+
query = attn.to_q(hidden_states)
|
140 |
+
key = attn.to_k(hidden_states)
|
141 |
+
value = attn.to_v(hidden_states)
|
142 |
+
|
143 |
+
inner_dim = key.shape[-1]
|
144 |
+
head_dim = inner_dim // attn.heads
|
145 |
+
|
146 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
147 |
+
|
148 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
149 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
150 |
+
|
151 |
+
if attn.norm_q is not None:
|
152 |
+
query = attn.norm_q(query)
|
153 |
+
if attn.norm_k is not None:
|
154 |
+
key = attn.norm_k(key)
|
155 |
+
|
156 |
+
# Apply RoPE if needed
|
157 |
+
if image_rotary_emb is not None:
|
158 |
+
# YiYi to-do: update uising apply_rotary_emb
|
159 |
+
# from ..embeddings import apply_rotary_emb
|
160 |
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
161 |
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
162 |
+
query, key = apply_rope(query, key, image_rotary_emb)
|
163 |
+
|
164 |
+
if attention_mask is not None:
|
165 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
166 |
+
attention_mask = (attention_mask > 0).bool()
|
167 |
+
attention_mask = attention_mask.to(
|
168 |
+
device=hidden_states.device, dtype=hidden_states.dtype
|
169 |
+
)
|
170 |
+
|
171 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
172 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
173 |
+
hidden_states = F.scaled_dot_product_attention(
|
174 |
+
query,
|
175 |
+
key,
|
176 |
+
value,
|
177 |
+
dropout_p=0.0,
|
178 |
+
is_causal=False,
|
179 |
+
attn_mask=attention_mask,
|
180 |
+
)
|
181 |
+
|
182 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
183 |
+
batch_size, -1, attn.heads * head_dim
|
184 |
+
)
|
185 |
+
hidden_states = hidden_states.to(query.dtype)
|
186 |
+
|
187 |
+
if input_ndim == 4:
|
188 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
189 |
+
batch_size, channel, height, width
|
190 |
+
)
|
191 |
+
|
192 |
+
return hidden_states
|
193 |
+
|
194 |
+
|
195 |
+
class FluxAttnProcessor2_0:
|
196 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
197 |
+
|
198 |
+
def __init__(self):
|
199 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
200 |
+
raise ImportError(
|
201 |
+
"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
202 |
+
)
|
203 |
+
|
204 |
+
def __call__(
|
205 |
+
self,
|
206 |
+
attn: Attention,
|
207 |
+
hidden_states: torch.FloatTensor,
|
208 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
209 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
210 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
211 |
+
) -> torch.FloatTensor:
|
212 |
+
input_ndim = hidden_states.ndim
|
213 |
+
if input_ndim == 4:
|
214 |
+
batch_size, channel, height, width = hidden_states.shape
|
215 |
+
hidden_states = hidden_states.view(
|
216 |
+
batch_size, channel, height * width
|
217 |
+
).transpose(1, 2)
|
218 |
+
context_input_ndim = encoder_hidden_states.ndim
|
219 |
+
if context_input_ndim == 4:
|
220 |
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
221 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
222 |
+
batch_size, channel, height * width
|
223 |
+
).transpose(1, 2)
|
224 |
+
|
225 |
+
batch_size = encoder_hidden_states.shape[0]
|
226 |
+
|
227 |
+
# `sample` projections.
|
228 |
+
query = attn.to_q(hidden_states)
|
229 |
+
key = attn.to_k(hidden_states)
|
230 |
+
value = attn.to_v(hidden_states)
|
231 |
+
|
232 |
+
inner_dim = key.shape[-1]
|
233 |
+
head_dim = inner_dim // attn.heads
|
234 |
+
|
235 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
236 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
237 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
238 |
+
|
239 |
+
if attn.norm_q is not None:
|
240 |
+
query = attn.norm_q(query)
|
241 |
+
if attn.norm_k is not None:
|
242 |
+
key = attn.norm_k(key)
|
243 |
+
|
244 |
+
# `context` projections.
|
245 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
246 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
247 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
248 |
+
|
249 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
250 |
+
batch_size, -1, attn.heads, head_dim
|
251 |
+
).transpose(1, 2)
|
252 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
253 |
+
batch_size, -1, attn.heads, head_dim
|
254 |
+
).transpose(1, 2)
|
255 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
256 |
+
batch_size, -1, attn.heads, head_dim
|
257 |
+
).transpose(1, 2)
|
258 |
+
|
259 |
+
if attn.norm_added_q is not None:
|
260 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(
|
261 |
+
encoder_hidden_states_query_proj
|
262 |
+
)
|
263 |
+
if attn.norm_added_k is not None:
|
264 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(
|
265 |
+
encoder_hidden_states_key_proj
|
266 |
+
)
|
267 |
+
|
268 |
+
# attention
|
269 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
270 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
271 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
272 |
+
|
273 |
+
if image_rotary_emb is not None:
|
274 |
+
# YiYi to-do: update uising apply_rotary_emb
|
275 |
+
# from ..embeddings import apply_rotary_emb
|
276 |
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
277 |
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
278 |
+
query, key = apply_rope(query, key, image_rotary_emb)
|
279 |
+
|
280 |
+
if attention_mask is not None:
|
281 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
282 |
+
attention_mask = (attention_mask > 0).bool()
|
283 |
+
attention_mask = attention_mask.to(
|
284 |
+
device=hidden_states.device, dtype=hidden_states.dtype
|
285 |
+
)
|
286 |
+
|
287 |
+
hidden_states = F.scaled_dot_product_attention(
|
288 |
+
query,
|
289 |
+
key,
|
290 |
+
value,
|
291 |
+
dropout_p=0.0,
|
292 |
+
is_causal=False,
|
293 |
+
attn_mask=attention_mask,
|
294 |
+
)
|
295 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
296 |
+
batch_size, -1, attn.heads * head_dim
|
297 |
+
)
|
298 |
+
hidden_states = hidden_states.to(query.dtype)
|
299 |
+
|
300 |
+
encoder_hidden_states, hidden_states = (
|
301 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
302 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
303 |
+
)
|
304 |
+
|
305 |
+
# linear proj
|
306 |
+
hidden_states = attn.to_out[0](hidden_states)
|
307 |
+
# dropout
|
308 |
+
hidden_states = attn.to_out[1](hidden_states)
|
309 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
310 |
+
|
311 |
+
if input_ndim == 4:
|
312 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
313 |
+
batch_size, channel, height, width
|
314 |
+
)
|
315 |
+
if context_input_ndim == 4:
|
316 |
+
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
|
317 |
+
batch_size, channel, height, width
|
318 |
+
)
|
319 |
+
|
320 |
+
return hidden_states, encoder_hidden_states
|
321 |
+
|
322 |
+
|
323 |
+
# YiYi to-do: refactor rope related functions/classes
|
324 |
+
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
325 |
+
assert dim % 2 == 0, "The dimension must be even."
|
326 |
+
|
327 |
+
scale = (
|
328 |
+
torch.arange(
|
329 |
+
0,
|
330 |
+
dim,
|
331 |
+
2,
|
332 |
+
dtype=torch.float64, # torch.float32 if torch.backends.mps.is_available() else
|
333 |
+
device=pos.device,
|
334 |
+
)
|
335 |
+
/ dim
|
336 |
+
)
|
337 |
+
omega = 1.0 / (theta**scale)
|
338 |
+
|
339 |
+
batch_size, seq_length = pos.shape
|
340 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
341 |
+
cos_out = torch.cos(out)
|
342 |
+
sin_out = torch.sin(out)
|
343 |
+
|
344 |
+
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
345 |
+
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
346 |
+
return out.float()
|
347 |
+
|
348 |
+
|
349 |
+
# YiYi to-do: refactor rope related functions/classes
|
350 |
+
class EmbedND(nn.Module):
|
351 |
+
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
352 |
+
super().__init__()
|
353 |
+
self.dim = dim
|
354 |
+
self.theta = theta
|
355 |
+
self.axes_dim = axes_dim
|
356 |
+
|
357 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
358 |
+
n_axes = ids.shape[-1]
|
359 |
+
emb = torch.cat(
|
360 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
361 |
+
dim=-3,
|
362 |
+
)
|
363 |
+
|
364 |
+
return emb.unsqueeze(1)
|
365 |
+
|
366 |
+
|
367 |
+
def expand_flux_attention_mask(
|
368 |
+
hidden_states: torch.Tensor,
|
369 |
+
attn_mask: torch.Tensor,
|
370 |
+
) -> torch.Tensor:
|
371 |
+
"""
|
372 |
+
Expand a mask so that the image is included.
|
373 |
+
"""
|
374 |
+
bsz = attn_mask.shape[0]
|
375 |
+
assert bsz == hidden_states.shape[0]
|
376 |
+
residual_seq_len = hidden_states.shape[1]
|
377 |
+
mask_seq_len = attn_mask.shape[1]
|
378 |
+
|
379 |
+
expanded_mask = torch.ones(bsz, residual_seq_len)
|
380 |
+
expanded_mask[:, :mask_seq_len] = attn_mask
|
381 |
+
|
382 |
+
return expanded_mask
|
383 |
+
|
384 |
+
|
385 |
+
@maybe_allow_in_graph
|
386 |
+
class FluxSingleTransformerBlock(nn.Module):
|
387 |
+
r"""
|
388 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
389 |
+
|
390 |
+
Reference: https://arxiv.org/abs/2403.03206
|
391 |
+
|
392 |
+
Parameters:
|
393 |
+
dim (`int`): The number of channels in the input and output.
|
394 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
395 |
+
attention_head_dim (`int`): The number of channels in each head.
|
396 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
397 |
+
processing of `context` conditions.
|
398 |
+
"""
|
399 |
+
|
400 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
401 |
+
super().__init__()
|
402 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
403 |
+
|
404 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
405 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
406 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
407 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
408 |
+
|
409 |
+
processor = FluxSingleAttnProcessor2_0()
|
410 |
+
self.attn = Attention(
|
411 |
+
query_dim=dim,
|
412 |
+
cross_attention_dim=None,
|
413 |
+
dim_head=attention_head_dim,
|
414 |
+
heads=num_attention_heads,
|
415 |
+
out_dim=dim,
|
416 |
+
bias=True,
|
417 |
+
processor=processor,
|
418 |
+
qk_norm="rms_norm",
|
419 |
+
eps=1e-6,
|
420 |
+
pre_only=True,
|
421 |
+
)
|
422 |
+
|
423 |
+
def forward(
|
424 |
+
self,
|
425 |
+
hidden_states: torch.FloatTensor,
|
426 |
+
temb: torch.FloatTensor,
|
427 |
+
image_rotary_emb=None,
|
428 |
+
attention_mask: Optional[torch.Tensor] = None,
|
429 |
+
):
|
430 |
+
residual = hidden_states
|
431 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
432 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
433 |
+
|
434 |
+
if attention_mask is not None:
|
435 |
+
attention_mask = expand_flux_attention_mask(
|
436 |
+
hidden_states,
|
437 |
+
attention_mask,
|
438 |
+
)
|
439 |
+
|
440 |
+
attn_output = self.attn(
|
441 |
+
hidden_states=norm_hidden_states,
|
442 |
+
image_rotary_emb=image_rotary_emb,
|
443 |
+
attention_mask=attention_mask,
|
444 |
+
)
|
445 |
+
|
446 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
447 |
+
gate = gate.unsqueeze(1)
|
448 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
449 |
+
hidden_states = residual + hidden_states
|
450 |
+
|
451 |
+
return hidden_states
|
452 |
+
|
453 |
+
|
454 |
+
@maybe_allow_in_graph
|
455 |
+
class FluxTransformerBlock(nn.Module):
|
456 |
+
r"""
|
457 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
458 |
+
|
459 |
+
Reference: https://arxiv.org/abs/2403.03206
|
460 |
+
|
461 |
+
Parameters:
|
462 |
+
dim (`int`): The number of channels in the input and output.
|
463 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
464 |
+
attention_head_dim (`int`): The number of channels in each head.
|
465 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
466 |
+
processing of `context` conditions.
|
467 |
+
"""
|
468 |
+
|
469 |
+
def __init__(
|
470 |
+
self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
|
471 |
+
):
|
472 |
+
super().__init__()
|
473 |
+
|
474 |
+
self.norm1 = AdaLayerNormZero(dim)
|
475 |
+
|
476 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
477 |
+
|
478 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
479 |
+
processor = FluxAttnProcessor2_0()
|
480 |
+
else:
|
481 |
+
raise ValueError(
|
482 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
483 |
+
)
|
484 |
+
self.attn = Attention(
|
485 |
+
query_dim=dim,
|
486 |
+
cross_attention_dim=None,
|
487 |
+
added_kv_proj_dim=dim,
|
488 |
+
dim_head=attention_head_dim,
|
489 |
+
heads=num_attention_heads,
|
490 |
+
out_dim=dim,
|
491 |
+
context_pre_only=False,
|
492 |
+
bias=True,
|
493 |
+
processor=processor,
|
494 |
+
qk_norm=qk_norm,
|
495 |
+
eps=eps,
|
496 |
+
)
|
497 |
+
|
498 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
499 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
500 |
+
|
501 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
502 |
+
self.ff_context = FeedForward(
|
503 |
+
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
|
504 |
+
)
|
505 |
+
|
506 |
+
# let chunk size default to None
|
507 |
+
self._chunk_size = None
|
508 |
+
self._chunk_dim = 0
|
509 |
+
|
510 |
+
def forward(
|
511 |
+
self,
|
512 |
+
hidden_states: torch.FloatTensor,
|
513 |
+
encoder_hidden_states: torch.FloatTensor,
|
514 |
+
temb: torch.FloatTensor,
|
515 |
+
image_rotary_emb=None,
|
516 |
+
attention_mask: Optional[torch.Tensor] = None,
|
517 |
+
):
|
518 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
519 |
+
hidden_states, emb=temb
|
520 |
+
)
|
521 |
+
|
522 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
523 |
+
self.norm1_context(encoder_hidden_states, emb=temb)
|
524 |
+
)
|
525 |
+
|
526 |
+
if attention_mask is not None:
|
527 |
+
attention_mask = expand_flux_attention_mask(
|
528 |
+
torch.cat([encoder_hidden_states, hidden_states], dim=1),
|
529 |
+
attention_mask,
|
530 |
+
)
|
531 |
+
|
532 |
+
# Attention.
|
533 |
+
attn_output, context_attn_output = self.attn(
|
534 |
+
hidden_states=norm_hidden_states,
|
535 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
536 |
+
image_rotary_emb=image_rotary_emb,
|
537 |
+
attention_mask=attention_mask,
|
538 |
+
)
|
539 |
+
|
540 |
+
# Process attention outputs for the `hidden_states`.
|
541 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
542 |
+
hidden_states = hidden_states + attn_output
|
543 |
+
|
544 |
+
norm_hidden_states = self.norm2(hidden_states)
|
545 |
+
norm_hidden_states = (
|
546 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
547 |
+
)
|
548 |
+
|
549 |
+
ff_output = self.ff(norm_hidden_states)
|
550 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
551 |
+
|
552 |
+
hidden_states = hidden_states + ff_output
|
553 |
+
|
554 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
555 |
+
|
556 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
557 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
558 |
+
|
559 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
560 |
+
norm_encoder_hidden_states = (
|
561 |
+
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
|
562 |
+
+ c_shift_mlp[:, None]
|
563 |
+
)
|
564 |
+
|
565 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
566 |
+
encoder_hidden_states = (
|
567 |
+
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
568 |
+
)
|
569 |
+
|
570 |
+
return encoder_hidden_states, hidden_states
|
571 |
+
|
572 |
+
|
573 |
+
class FluxTransformer2DModelWithMasking(
|
574 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
|
575 |
+
):
|
576 |
+
"""
|
577 |
+
The Transformer model introduced in Flux.
|
578 |
+
|
579 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
580 |
+
|
581 |
+
Parameters:
|
582 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
583 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
584 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
585 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
586 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
587 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
588 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
589 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
590 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
591 |
+
"""
|
592 |
+
|
593 |
+
_supports_gradient_checkpointing = True
|
594 |
+
|
595 |
+
@register_to_config
|
596 |
+
def __init__(
|
597 |
+
self,
|
598 |
+
patch_size: int = 1,
|
599 |
+
in_channels: int = 64,
|
600 |
+
num_layers: int = 19,
|
601 |
+
num_single_layers: int = 38,
|
602 |
+
attention_head_dim: int = 128,
|
603 |
+
num_attention_heads: int = 24,
|
604 |
+
joint_attention_dim: int = 4096,
|
605 |
+
pooled_projection_dim: int = 768,
|
606 |
+
guidance_embeds: bool = False,
|
607 |
+
axes_dims_rope: List[int] = [16, 56, 56],
|
608 |
+
):
|
609 |
+
super().__init__()
|
610 |
+
self.out_channels = in_channels
|
611 |
+
self.inner_dim = (
|
612 |
+
self.config.num_attention_heads * self.config.attention_head_dim
|
613 |
+
)
|
614 |
+
|
615 |
+
self.pos_embed = EmbedND(
|
616 |
+
dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
|
617 |
+
)
|
618 |
+
text_time_guidance_cls = (
|
619 |
+
CombinedTimestepGuidanceTextProjEmbeddings
|
620 |
+
if guidance_embeds
|
621 |
+
else CombinedTimestepTextProjEmbeddings
|
622 |
+
)
|
623 |
+
self.time_text_embed = text_time_guidance_cls(
|
624 |
+
embedding_dim=self.inner_dim,
|
625 |
+
pooled_projection_dim=self.config.pooled_projection_dim,
|
626 |
+
)
|
627 |
+
|
628 |
+
self.context_embedder = nn.Linear(
|
629 |
+
self.config.joint_attention_dim, self.inner_dim
|
630 |
+
)
|
631 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
632 |
+
|
633 |
+
self.transformer_blocks = nn.ModuleList(
|
634 |
+
[
|
635 |
+
FluxTransformerBlock(
|
636 |
+
dim=self.inner_dim,
|
637 |
+
num_attention_heads=self.config.num_attention_heads,
|
638 |
+
attention_head_dim=self.config.attention_head_dim,
|
639 |
+
)
|
640 |
+
for i in range(self.config.num_layers)
|
641 |
+
]
|
642 |
+
)
|
643 |
+
|
644 |
+
self.single_transformer_blocks = nn.ModuleList(
|
645 |
+
[
|
646 |
+
FluxSingleTransformerBlock(
|
647 |
+
dim=self.inner_dim,
|
648 |
+
num_attention_heads=self.config.num_attention_heads,
|
649 |
+
attention_head_dim=self.config.attention_head_dim,
|
650 |
+
)
|
651 |
+
for i in range(self.config.num_single_layers)
|
652 |
+
]
|
653 |
+
)
|
654 |
+
|
655 |
+
self.norm_out = AdaLayerNormContinuous(
|
656 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
657 |
+
)
|
658 |
+
self.proj_out = nn.Linear(
|
659 |
+
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
|
660 |
+
)
|
661 |
+
|
662 |
+
self.gradient_checkpointing = False
|
663 |
+
|
664 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
665 |
+
if hasattr(module, "gradient_checkpointing"):
|
666 |
+
module.gradient_checkpointing = value
|
667 |
+
|
668 |
+
def forward(
|
669 |
+
self,
|
670 |
+
hidden_states: torch.Tensor,
|
671 |
+
encoder_hidden_states: torch.Tensor = None,
|
672 |
+
pooled_projections: torch.Tensor = None,
|
673 |
+
timestep: torch.LongTensor = None,
|
674 |
+
img_ids: torch.Tensor = None,
|
675 |
+
txt_ids: torch.Tensor = None,
|
676 |
+
guidance: torch.Tensor = None,
|
677 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
678 |
+
return_dict: bool = True,
|
679 |
+
attention_mask: Optional[torch.Tensor] = None,
|
680 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
681 |
+
"""
|
682 |
+
The [`FluxTransformer2DModelWithMasking`] forward method.
|
683 |
+
|
684 |
+
Args:
|
685 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
686 |
+
Input `hidden_states`.
|
687 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
688 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
689 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
690 |
+
from the embeddings of input conditions.
|
691 |
+
timestep ( `torch.LongTensor`):
|
692 |
+
Used to indicate denoising step.
|
693 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
694 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
695 |
+
joint_attention_kwargs (`dict`, *optional*):
|
696 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
697 |
+
`self.processor` in
|
698 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
699 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
700 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
701 |
+
tuple.
|
702 |
+
|
703 |
+
Returns:
|
704 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
705 |
+
`tuple` where the first element is the sample tensor.
|
706 |
+
"""
|
707 |
+
if joint_attention_kwargs is not None:
|
708 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
709 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
710 |
+
else:
|
711 |
+
lora_scale = 1.0
|
712 |
+
|
713 |
+
if USE_PEFT_BACKEND:
|
714 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
715 |
+
scale_lora_layers(self, lora_scale)
|
716 |
+
else:
|
717 |
+
if (
|
718 |
+
joint_attention_kwargs is not None
|
719 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
720 |
+
):
|
721 |
+
logger.warning(
|
722 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
723 |
+
)
|
724 |
+
hidden_states = self.x_embedder(hidden_states)
|
725 |
+
|
726 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
727 |
+
if guidance is not None:
|
728 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
729 |
+
else:
|
730 |
+
guidance = None
|
731 |
+
temb = (
|
732 |
+
self.time_text_embed(timestep, pooled_projections)
|
733 |
+
if guidance is None
|
734 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
735 |
+
)
|
736 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
737 |
+
|
738 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
739 |
+
image_rotary_emb = self.pos_embed(ids)
|
740 |
+
|
741 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
742 |
+
if self.training and self.gradient_checkpointing:
|
743 |
+
|
744 |
+
def create_custom_forward(module, return_dict=None):
|
745 |
+
def custom_forward(*inputs):
|
746 |
+
if return_dict is not None:
|
747 |
+
return module(*inputs, return_dict=return_dict)
|
748 |
+
else:
|
749 |
+
return module(*inputs)
|
750 |
+
|
751 |
+
return custom_forward
|
752 |
+
|
753 |
+
ckpt_kwargs: Dict[str, Any] = (
|
754 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
755 |
+
)
|
756 |
+
encoder_hidden_states, hidden_states = (
|
757 |
+
torch.utils.checkpoint.checkpoint(
|
758 |
+
create_custom_forward(block),
|
759 |
+
hidden_states,
|
760 |
+
encoder_hidden_states,
|
761 |
+
temb,
|
762 |
+
image_rotary_emb,
|
763 |
+
attention_mask,
|
764 |
+
**ckpt_kwargs,
|
765 |
+
)
|
766 |
+
)
|
767 |
+
|
768 |
+
else:
|
769 |
+
encoder_hidden_states, hidden_states = block(
|
770 |
+
hidden_states=hidden_states,
|
771 |
+
encoder_hidden_states=encoder_hidden_states,
|
772 |
+
temb=temb,
|
773 |
+
image_rotary_emb=image_rotary_emb,
|
774 |
+
attention_mask=attention_mask,
|
775 |
+
)
|
776 |
+
|
777 |
+
# Flux places the text tokens in front of the image tokens in the
|
778 |
+
# sequence.
|
779 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
780 |
+
|
781 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
782 |
+
if self.training and self.gradient_checkpointing:
|
783 |
+
|
784 |
+
def create_custom_forward(module, return_dict=None):
|
785 |
+
def custom_forward(*inputs):
|
786 |
+
if return_dict is not None:
|
787 |
+
return module(*inputs, return_dict=return_dict)
|
788 |
+
else:
|
789 |
+
return module(*inputs)
|
790 |
+
|
791 |
+
return custom_forward
|
792 |
+
|
793 |
+
ckpt_kwargs: Dict[str, Any] = (
|
794 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
795 |
+
)
|
796 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
797 |
+
create_custom_forward(block),
|
798 |
+
hidden_states,
|
799 |
+
temb,
|
800 |
+
image_rotary_emb,
|
801 |
+
attention_mask,
|
802 |
+
**ckpt_kwargs,
|
803 |
+
)
|
804 |
+
|
805 |
+
else:
|
806 |
+
hidden_states = block(
|
807 |
+
hidden_states=hidden_states,
|
808 |
+
temb=temb,
|
809 |
+
image_rotary_emb=image_rotary_emb,
|
810 |
+
attention_mask=attention_mask,
|
811 |
+
)
|
812 |
+
|
813 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
814 |
+
|
815 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
816 |
+
output = self.proj_out(hidden_states)
|
817 |
+
|
818 |
+
if USE_PEFT_BACKEND:
|
819 |
+
# remove `lora_scale` from each PEFT layer
|
820 |
+
unscale_lora_layers(self, lora_scale)
|
821 |
+
|
822 |
+
if not return_dict:
|
823 |
+
return (output,)
|
824 |
+
|
825 |
+
return Transformer2DModelOutput(sample=output)
|
826 |
+
|
827 |
+
|
828 |
+
if __name__ == "__main__":
|
829 |
+
dtype = torch.bfloat16
|
830 |
+
bsz = 2
|
831 |
+
img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
|
832 |
+
timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
|
833 |
+
pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
|
834 |
+
text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
|
835 |
+
attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
|
836 |
+
"cuda", dtype=dtype
|
837 |
+
) # Last 128 positions are masked
|
838 |
+
|
839 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
840 |
+
latents = latents.view(
|
841 |
+
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
842 |
+
)
|
843 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
844 |
+
latents = latents.reshape(
|
845 |
+
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
846 |
+
)
|
847 |
+
|
848 |
+
return latents
|
849 |
+
|
850 |
+
def _prepare_latent_image_ids(
|
851 |
+
batch_size, height, width, device="cuda", dtype=dtype
|
852 |
+
):
|
853 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
854 |
+
latent_image_ids[..., 1] = (
|
855 |
+
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
856 |
+
)
|
857 |
+
latent_image_ids[..., 2] = (
|
858 |
+
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
859 |
+
)
|
860 |
+
|
861 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
862 |
+
latent_image_ids.shape
|
863 |
+
)
|
864 |
+
|
865 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
866 |
+
latent_image_ids = latent_image_ids.reshape(
|
867 |
+
batch_size,
|
868 |
+
latent_image_id_height * latent_image_id_width,
|
869 |
+
latent_image_id_channels,
|
870 |
+
)
|
871 |
+
|
872 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
873 |
+
|
874 |
+
txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
|
875 |
+
|
876 |
+
vae_scale_factor = 16
|
877 |
+
height = 2 * (int(512) // vae_scale_factor)
|
878 |
+
width = 2 * (int(512) // vae_scale_factor)
|
879 |
+
img_ids = _prepare_latent_image_ids(bsz, height, width)
|
880 |
+
img = _pack_latents(img, img.shape[0], 16, height, width)
|
881 |
+
|
882 |
+
# Gotta go fast
|
883 |
+
transformer = FluxTransformer2DModelWithMasking.from_config(
|
884 |
+
{
|
885 |
+
"attention_head_dim": 128,
|
886 |
+
"guidance_embeds": True,
|
887 |
+
"in_channels": 64,
|
888 |
+
"joint_attention_dim": 4096,
|
889 |
+
"num_attention_heads": 24,
|
890 |
+
"num_layers": 4,
|
891 |
+
"num_single_layers": 8,
|
892 |
+
"patch_size": 1,
|
893 |
+
"pooled_projection_dim": 768,
|
894 |
+
}
|
895 |
+
).to("cuda", dtype=dtype)
|
896 |
+
|
897 |
+
guidance = torch.tensor([2.0], device="cuda")
|
898 |
+
guidance = guidance.expand(bsz)
|
899 |
+
|
900 |
+
with torch.no_grad():
|
901 |
+
no_mask = transformer(
|
902 |
+
img,
|
903 |
+
encoder_hidden_states=text,
|
904 |
+
pooled_projections=pooled,
|
905 |
+
timestep=timestep,
|
906 |
+
img_ids=img_ids,
|
907 |
+
txt_ids=txt_ids,
|
908 |
+
guidance=guidance,
|
909 |
+
)
|
910 |
+
mask = transformer(
|
911 |
+
img,
|
912 |
+
encoder_hidden_states=text,
|
913 |
+
pooled_projections=pooled,
|
914 |
+
timestep=timestep,
|
915 |
+
img_ids=img_ids,
|
916 |
+
txt_ids=txt_ids,
|
917 |
+
guidance=guidance,
|
918 |
+
attention_mask=attn_mask,
|
919 |
+
)
|
920 |
+
|
921 |
+
assert torch.allclose(no_mask.sample, mask.sample) is False
|
922 |
+
print("Attention masking test ran OK. Differences in output were detected.")
|
923 |
+
|
924 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
925 |
+
|
926 |
+
EXAMPLE_DOC_STRING = """
|
927 |
+
Examples:
|
928 |
+
```py
|
929 |
+
>>> import torch
|
930 |
+
>>> from diffusers import FluxPipeline
|
931 |
+
|
932 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
933 |
+
>>> pipe.to("cuda")
|
934 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
935 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
936 |
+
>>> # Refer to the pipeline documentation for more details.
|
937 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
938 |
+
>>> image.save("flux.png")
|
939 |
+
```
|
940 |
+
"""
|
941 |
+
|
942 |
+
|
943 |
+
def calculate_shift(
|
944 |
+
image_seq_len,
|
945 |
+
base_seq_len: int = 256,
|
946 |
+
max_seq_len: int = 4096,
|
947 |
+
base_shift: float = 0.5,
|
948 |
+
max_shift: float = 1.16,
|
949 |
+
):
|
950 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
951 |
+
b = base_shift - m * base_seq_len
|
952 |
+
mu = image_seq_len * m + b
|
953 |
+
return mu
|
954 |
+
|
955 |
+
|
956 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
957 |
+
def retrieve_timesteps(
|
958 |
+
scheduler,
|
959 |
+
num_inference_steps: Optional[int] = None,
|
960 |
+
device: Optional[Union[str, torch.device]] = None,
|
961 |
+
timesteps: Optional[List[int]] = None,
|
962 |
+
sigmas: Optional[List[float]] = None,
|
963 |
+
**kwargs,
|
964 |
+
):
|
965 |
+
"""
|
966 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
967 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
968 |
+
|
969 |
+
Args:
|
970 |
+
scheduler (`SchedulerMixin`):
|
971 |
+
The scheduler to get timesteps from.
|
972 |
+
num_inference_steps (`int`):
|
973 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
974 |
+
must be `None`.
|
975 |
+
device (`str` or `torch.device`, *optional*):
|
976 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
977 |
+
timesteps (`List[int]`, *optional*):
|
978 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
979 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
980 |
+
sigmas (`List[float]`, *optional*):
|
981 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
982 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
983 |
+
|
984 |
+
Returns:
|
985 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
986 |
+
second element is the number of inference steps.
|
987 |
+
"""
|
988 |
+
if timesteps is not None and sigmas is not None:
|
989 |
+
raise ValueError(
|
990 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
991 |
+
)
|
992 |
+
if timesteps is not None:
|
993 |
+
accepts_timesteps = "timesteps" in set(
|
994 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
995 |
+
)
|
996 |
+
if not accepts_timesteps:
|
997 |
+
raise ValueError(
|
998 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
999 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
1000 |
+
)
|
1001 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
1002 |
+
timesteps = scheduler.timesteps
|
1003 |
+
num_inference_steps = len(timesteps)
|
1004 |
+
elif sigmas is not None:
|
1005 |
+
accept_sigmas = "sigmas" in set(
|
1006 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
1007 |
+
)
|
1008 |
+
if not accept_sigmas:
|
1009 |
+
raise ValueError(
|
1010 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
1011 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
1012 |
+
)
|
1013 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
1014 |
+
timesteps = scheduler.timesteps
|
1015 |
+
num_inference_steps = len(timesteps)
|
1016 |
+
else:
|
1017 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
1018 |
+
timesteps = scheduler.timesteps
|
1019 |
+
return timesteps, num_inference_steps
|
1020 |
+
|
1021 |
+
|
1022 |
+
class CustomPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
1023 |
+
r"""
|
1024 |
+
The Flux pipeline for text-to-image generation.
|
1025 |
+
|
1026 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
1027 |
+
|
1028 |
+
Args:
|
1029 |
+
transformer ([`FluxTransformer2DModelWithMasking`]):
|
1030 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
1031 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
1032 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
1033 |
+
vae ([`AutoencoderKL`]):
|
1034 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
1035 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
1036 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
1037 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
1038 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
1039 |
+
as its dimension.
|
1040 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
1041 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
1042 |
+
specifically the
|
1043 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
1044 |
+
variant.
|
1045 |
+
tokenizer (`CLIPTokenizer`):
|
1046 |
+
Tokenizer of class
|
1047 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
1048 |
+
tokenizer_2 (`CLIPTokenizer`):
|
1049 |
+
Second Tokenizer of class
|
1050 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
1051 |
+
"""
|
1052 |
+
|
1053 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
1054 |
+
_optional_components = []
|
1055 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
1056 |
+
|
1057 |
+
def __init__(
|
1058 |
+
self,
|
1059 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
1060 |
+
vae: AutoencoderKL,
|
1061 |
+
text_encoder: CLIPTextModel,
|
1062 |
+
tokenizer: CLIPTokenizer,
|
1063 |
+
text_encoder_2: T5EncoderModel,
|
1064 |
+
tokenizer_2: T5TokenizerFast,
|
1065 |
+
transformer: FluxTransformer2DModelWithMasking,
|
1066 |
+
):
|
1067 |
+
super().__init__()
|
1068 |
+
|
1069 |
+
self.register_modules(
|
1070 |
+
vae=vae,
|
1071 |
+
text_encoder=text_encoder,
|
1072 |
+
text_encoder_2=text_encoder_2,
|
1073 |
+
tokenizer=tokenizer,
|
1074 |
+
tokenizer_2=tokenizer_2,
|
1075 |
+
transformer=transformer,
|
1076 |
+
scheduler=scheduler,
|
1077 |
+
)
|
1078 |
+
self.vae_scale_factor = (
|
1079 |
+
2 ** (len(self.vae.config.block_out_channels))
|
1080 |
+
if hasattr(self, "vae") and self.vae is not None
|
1081 |
+
else 16
|
1082 |
+
)
|
1083 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
1084 |
+
self.tokenizer_max_length = (
|
1085 |
+
self.tokenizer.model_max_length
|
1086 |
+
if hasattr(self, "tokenizer") and self.tokenizer is not None
|
1087 |
+
else 77
|
1088 |
+
)
|
1089 |
+
self.default_sample_size = 64
|
1090 |
+
|
1091 |
+
def _get_t5_prompt_embeds(
|
1092 |
+
self,
|
1093 |
+
prompt: Union[str, List[str]] = None,
|
1094 |
+
num_images_per_prompt: int = 1,
|
1095 |
+
max_sequence_length: int = 512,
|
1096 |
+
device: Optional[torch.device] = None,
|
1097 |
+
dtype: Optional[torch.dtype] = None,
|
1098 |
+
):
|
1099 |
+
device = device or self._execution_device
|
1100 |
+
dtype = dtype or self.text_encoder.dtype
|
1101 |
+
|
1102 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
1103 |
+
batch_size = len(prompt)
|
1104 |
+
|
1105 |
+
text_inputs = self.tokenizer_2(
|
1106 |
+
prompt,
|
1107 |
+
padding="max_length",
|
1108 |
+
max_length=max_sequence_length,
|
1109 |
+
truncation=True,
|
1110 |
+
return_length=False,
|
1111 |
+
return_overflowing_tokens=False,
|
1112 |
+
return_tensors="pt",
|
1113 |
+
)
|
1114 |
+
prompt_attention_mask = text_inputs.attention_mask
|
1115 |
+
text_input_ids = text_inputs.input_ids
|
1116 |
+
untruncated_ids = self.tokenizer_2(
|
1117 |
+
prompt, padding="longest", return_tensors="pt"
|
1118 |
+
).input_ids
|
1119 |
+
|
1120 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
1121 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
1122 |
+
logger.warning(
|
1123 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
1124 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
1125 |
+
)
|
1126 |
+
|
1127 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
1128 |
+
|
1129 |
+
dtype = self.text_encoder_2.dtype
|
1130 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
1131 |
+
|
1132 |
+
_, seq_len, _ = prompt_embeds.shape
|
1133 |
+
|
1134 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
1135 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
1136 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
1137 |
+
|
1138 |
+
return prompt_embeds, prompt_attention_mask
|
1139 |
+
|
1140 |
+
def _get_clip_prompt_embeds(
|
1141 |
+
self,
|
1142 |
+
prompt: Union[str, List[str]],
|
1143 |
+
num_images_per_prompt: int = 1,
|
1144 |
+
device: Optional[torch.device] = None,
|
1145 |
+
):
|
1146 |
+
device = device or self._execution_device
|
1147 |
+
|
1148 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
1149 |
+
batch_size = len(prompt)
|
1150 |
+
|
1151 |
+
text_inputs = self.tokenizer(
|
1152 |
+
prompt,
|
1153 |
+
padding="max_length",
|
1154 |
+
max_length=self.tokenizer_max_length,
|
1155 |
+
truncation=True,
|
1156 |
+
return_overflowing_tokens=False,
|
1157 |
+
return_length=False,
|
1158 |
+
return_tensors="pt",
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
text_input_ids = text_inputs.input_ids
|
1162 |
+
untruncated_ids = self.tokenizer(
|
1163 |
+
prompt, padding="longest", return_tensors="pt"
|
1164 |
+
).input_ids
|
1165 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
1166 |
+
text_input_ids, untruncated_ids
|
1167 |
+
):
|
1168 |
+
removed_text = self.tokenizer.batch_decode(
|
1169 |
+
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
|
1170 |
+
)
|
1171 |
+
logger.warning(
|
1172 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
1173 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
1174 |
+
)
|
1175 |
+
prompt_embeds = self.text_encoder(
|
1176 |
+
text_input_ids.to(device), output_hidden_states=False
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
# Use pooled output of CLIPTextModel
|
1180 |
+
prompt_embeds = prompt_embeds.pooler_output
|
1181 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
1182 |
+
|
1183 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
1184 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
1185 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
1186 |
+
|
1187 |
+
return prompt_embeds
|
1188 |
+
|
1189 |
+
@lru_cache(maxsize=128)
|
1190 |
+
def encode_prompt(
|
1191 |
+
self,
|
1192 |
+
prompt: Union[str, List[str]],
|
1193 |
+
prompt_2: Union[str, List[str]],
|
1194 |
+
device: Optional[torch.device] = None,
|
1195 |
+
num_images_per_prompt: int = 1,
|
1196 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1197 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1198 |
+
max_sequence_length: int = 512,
|
1199 |
+
lora_scale: Optional[float] = None,
|
1200 |
+
):
|
1201 |
+
r"""
|
1202 |
+
|
1203 |
+
Args:
|
1204 |
+
prompt (`str` or `List[str]`, *optional*):
|
1205 |
+
prompt to be encoded
|
1206 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1207 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1208 |
+
used in all text-encoders
|
1209 |
+
device: (`torch.device`):
|
1210 |
+
torch device
|
1211 |
+
num_images_per_prompt (`int`):
|
1212 |
+
number of images that should be generated per prompt
|
1213 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1214 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1215 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1216 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1217 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1218 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1219 |
+
clip_skip (`int`, *optional*):
|
1220 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
1221 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
1222 |
+
lora_scale (`float`, *optional*):
|
1223 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
1224 |
+
"""
|
1225 |
+
device = device or self._execution_device
|
1226 |
+
|
1227 |
+
# set lora scale so that monkey patched LoRA
|
1228 |
+
# function of text encoder can correctly access it
|
1229 |
+
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
1230 |
+
self._lora_scale = lora_scale
|
1231 |
+
|
1232 |
+
# dynamically adjust the LoRA scale
|
1233 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
1234 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
1235 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
1236 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
1237 |
+
|
1238 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
1239 |
+
if prompt is not None:
|
1240 |
+
batch_size = len(prompt)
|
1241 |
+
else:
|
1242 |
+
batch_size = prompt_embeds.shape[0]
|
1243 |
+
|
1244 |
+
prompt_attention_mask = None
|
1245 |
+
if prompt_embeds is None:
|
1246 |
+
prompt_2 = prompt_2 or prompt
|
1247 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
1248 |
+
|
1249 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
1250 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
1251 |
+
prompt=prompt,
|
1252 |
+
device=device,
|
1253 |
+
num_images_per_prompt=num_images_per_prompt,
|
1254 |
+
)
|
1255 |
+
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
1256 |
+
prompt=prompt_2,
|
1257 |
+
num_images_per_prompt=num_images_per_prompt,
|
1258 |
+
max_sequence_length=max_sequence_length,
|
1259 |
+
device=device,
|
1260 |
+
)
|
1261 |
+
|
1262 |
+
if self.text_encoder is not None:
|
1263 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
1264 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
1265 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
1266 |
+
|
1267 |
+
if self.text_encoder_2 is not None:
|
1268 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
1269 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
1270 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
1271 |
+
|
1272 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
1273 |
+
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
1274 |
+
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
1275 |
+
|
1276 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
|
1277 |
+
|
1278 |
+
def check_inputs(
|
1279 |
+
self,
|
1280 |
+
prompt,
|
1281 |
+
prompt_2,
|
1282 |
+
height,
|
1283 |
+
width,
|
1284 |
+
prompt_embeds=None,
|
1285 |
+
pooled_prompt_embeds=None,
|
1286 |
+
callback_on_step_end_tensor_inputs=None,
|
1287 |
+
max_sequence_length=None,
|
1288 |
+
):
|
1289 |
+
if height % 8 != 0 or width % 8 != 0:
|
1290 |
+
raise ValueError(
|
1291 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
1292 |
+
)
|
1293 |
+
|
1294 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
1295 |
+
k in self._callback_tensor_inputs
|
1296 |
+
for k in callback_on_step_end_tensor_inputs
|
1297 |
+
):
|
1298 |
+
raise ValueError(
|
1299 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
1300 |
+
)
|
1301 |
+
|
1302 |
+
if prompt is not None and prompt_embeds is not None:
|
1303 |
+
raise ValueError(
|
1304 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
1305 |
+
" only forward one of the two."
|
1306 |
+
)
|
1307 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
1308 |
+
raise ValueError(
|
1309 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
1310 |
+
" only forward one of the two."
|
1311 |
+
)
|
1312 |
+
elif prompt is None and prompt_embeds is None:
|
1313 |
+
raise ValueError(
|
1314 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
1315 |
+
)
|
1316 |
+
elif prompt is not None and (
|
1317 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
1318 |
+
):
|
1319 |
+
raise ValueError(
|
1320 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
1321 |
+
)
|
1322 |
+
elif prompt_2 is not None and (
|
1323 |
+
not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
|
1324 |
+
):
|
1325 |
+
raise ValueError(
|
1326 |
+
f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
1330 |
+
raise ValueError(
|
1331 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
1332 |
+
)
|
1333 |
+
|
1334 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
1335 |
+
raise ValueError(
|
1336 |
+
f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
@staticmethod
|
1340 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
1341 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
1342 |
+
latent_image_ids[..., 1] = (
|
1343 |
+
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
1344 |
+
)
|
1345 |
+
latent_image_ids[..., 2] = (
|
1346 |
+
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
1350 |
+
latent_image_ids.shape
|
1351 |
+
)
|
1352 |
+
|
1353 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
1354 |
+
latent_image_ids = latent_image_ids.reshape(
|
1355 |
+
batch_size,
|
1356 |
+
latent_image_id_height * latent_image_id_width,
|
1357 |
+
latent_image_id_channels,
|
1358 |
+
)
|
1359 |
+
|
1360 |
+
return latent_image_ids
|
1361 |
+
|
1362 |
+
@staticmethod
|
1363 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
1364 |
+
latents = latents.view(
|
1365 |
+
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
1366 |
+
)
|
1367 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
1368 |
+
latents = latents.reshape(
|
1369 |
+
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
return latents
|
1373 |
+
|
1374 |
+
@staticmethod
|
1375 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
1376 |
+
batch_size, num_patches, channels = latents.shape
|
1377 |
+
|
1378 |
+
height = height // vae_scale_factor
|
1379 |
+
width = width // vae_scale_factor
|
1380 |
+
|
1381 |
+
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
1382 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
1383 |
+
|
1384 |
+
latents = latents.reshape(
|
1385 |
+
batch_size, channels // (2 * 2), height * 2, width * 2
|
1386 |
+
)
|
1387 |
+
|
1388 |
+
return latents
|
1389 |
+
|
1390 |
+
def prepare_latents(
|
1391 |
+
self,
|
1392 |
+
batch_size,
|
1393 |
+
num_channels_latents,
|
1394 |
+
height,
|
1395 |
+
width,
|
1396 |
+
dtype,
|
1397 |
+
device,
|
1398 |
+
generator,
|
1399 |
+
latents=None,
|
1400 |
+
):
|
1401 |
+
height = 2 * (int(height) // self.vae_scale_factor)
|
1402 |
+
width = 2 * (int(width) // self.vae_scale_factor)
|
1403 |
+
|
1404 |
+
shape = (batch_size, num_channels_latents, height, width)
|
1405 |
+
|
1406 |
+
if latents is not None:
|
1407 |
+
latent_image_ids = self._prepare_latent_image_ids(
|
1408 |
+
batch_size, height, width, device, dtype
|
1409 |
+
)
|
1410 |
+
return latents, latent_image_ids
|
1411 |
+
|
1412 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
1413 |
+
raise ValueError(
|
1414 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
1415 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
1416 |
+
)
|
1417 |
+
|
1418 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
1419 |
+
latents = self._pack_latents(
|
1420 |
+
latents, batch_size, num_channels_latents, height, width
|
1421 |
+
)
|
1422 |
+
|
1423 |
+
latent_image_ids = self._prepare_latent_image_ids(
|
1424 |
+
batch_size, height, width, device, dtype
|
1425 |
+
)
|
1426 |
+
|
1427 |
+
return latents, latent_image_ids
|
1428 |
+
|
1429 |
+
@property
|
1430 |
+
def guidance_scale(self):
|
1431 |
+
return self._guidance_scale
|
1432 |
+
|
1433 |
+
@property
|
1434 |
+
def joint_attention_kwargs(self):
|
1435 |
+
return self._joint_attention_kwargs
|
1436 |
+
|
1437 |
+
@property
|
1438 |
+
def num_timesteps(self):
|
1439 |
+
return self._num_timesteps
|
1440 |
+
|
1441 |
+
@property
|
1442 |
+
def interrupt(self):
|
1443 |
+
return self._interrupt
|
1444 |
+
|
1445 |
+
@torch.no_grad()
|
1446 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
1447 |
+
def __call__(
|
1448 |
+
self,
|
1449 |
+
prompt: Union[str, List[str]] = None,
|
1450 |
+
prompt_mask: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]] = None,
|
1451 |
+
negative_mask: Optional[
|
1452 |
+
Union[torch.FloatTensor, List[torch.FloatTensor]]
|
1453 |
+
] = None,
|
1454 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
1455 |
+
height: Optional[int] = None,
|
1456 |
+
width: Optional[int] = None,
|
1457 |
+
num_inference_steps: int = 28,
|
1458 |
+
timesteps: List[int] = None,
|
1459 |
+
guidance_scale: float = 3.5,
|
1460 |
+
num_images_per_prompt: Optional[int] = 1,
|
1461 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1462 |
+
latents: Optional[torch.FloatTensor] = None,
|
1463 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1464 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1465 |
+
output_type: Optional[str] = "pil",
|
1466 |
+
return_dict: bool = True,
|
1467 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1468 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
1469 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
1470 |
+
max_sequence_length: int = 512,
|
1471 |
+
guidance_scale_real: float = 1.0,
|
1472 |
+
negative_prompt: Union[str, List[str]] = "",
|
1473 |
+
negative_prompt_2: Union[str, List[str]] = "",
|
1474 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1475 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1476 |
+
no_cfg_until_timestep: int = 0,
|
1477 |
+
use_prompt_mask: bool = True,
|
1478 |
+
zero_using_prompt_mask: bool = False,
|
1479 |
+
device=torch.device('cuda'), # TODO let this work with non-cuda stuff? Might if you set this to None
|
1480 |
+
):
|
1481 |
+
r"""
|
1482 |
+
Function invoked when calling the pipeline for generation.
|
1483 |
+
|
1484 |
+
Args:
|
1485 |
+
prompt (`str` or `List[str]`, *optional*):
|
1486 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
1487 |
+
instead.
|
1488 |
+
prompt_mask (`str` or `List[str]`, *optional*):
|
1489 |
+
The prompt or prompts to be used as a mask for the image generation. If not defined, `prompt` is used
|
1490 |
+
instead.
|
1491 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1492 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1493 |
+
will be used instead
|
1494 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1495 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
1496 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
1497 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
1498 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1499 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1500 |
+
expense of slower inference.
|
1501 |
+
timesteps (`List[int]`, *optional*):
|
1502 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
1503 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
1504 |
+
passed will be used. Must be in descending order.
|
1505 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
1506 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1507 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1508 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1509 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1510 |
+
usually at the expense of lower image quality.
|
1511 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1512 |
+
The number of images to generate per prompt.
|
1513 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1514 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1515 |
+
to make generation deterministic.
|
1516 |
+
latents (`torch.FloatTensor`, *optional*):
|
1517 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1518 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1519 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
1520 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1521 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1522 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1523 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1524 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1525 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1526 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1527 |
+
The output format of the generate image. Choose between
|
1528 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1529 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1530 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
1531 |
+
joint_attention_kwargs (`dict`, *optional*):
|
1532 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1533 |
+
`self.processor` in
|
1534 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1535 |
+
callback_on_step_end (`Callable`, *optional*):
|
1536 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
1537 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1538 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1539 |
+
`callback_on_step_end_tensor_inputs`.
|
1540 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1541 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1542 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1543 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1544 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
1545 |
+
|
1546 |
+
Examples:
|
1547 |
+
|
1548 |
+
Returns:
|
1549 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
1550 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
1551 |
+
images.
|
1552 |
+
"""
|
1553 |
+
|
1554 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1555 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1556 |
+
|
1557 |
+
# 1. Check inputs. Raise error if not correct
|
1558 |
+
self.check_inputs(
|
1559 |
+
prompt,
|
1560 |
+
prompt_2,
|
1561 |
+
height,
|
1562 |
+
width,
|
1563 |
+
prompt_embeds=prompt_embeds,
|
1564 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1565 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
1566 |
+
max_sequence_length=max_sequence_length,
|
1567 |
+
)
|
1568 |
+
|
1569 |
+
self._guidance_scale = guidance_scale
|
1570 |
+
self._guidance_scale_real = guidance_scale_real
|
1571 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
1572 |
+
self._interrupt = False
|
1573 |
+
|
1574 |
+
# 2. Define call parameters
|
1575 |
+
if prompt is not None and isinstance(prompt, str):
|
1576 |
+
batch_size = 1
|
1577 |
+
elif prompt is not None and isinstance(prompt, list):
|
1578 |
+
batch_size = len(prompt)
|
1579 |
+
else:
|
1580 |
+
batch_size = prompt_embeds.shape[0]
|
1581 |
+
|
1582 |
+
device = device or self._execution_device
|
1583 |
+
|
1584 |
+
lora_scale = (
|
1585 |
+
self.joint_attention_kwargs.get("scale", None)
|
1586 |
+
if self.joint_attention_kwargs is not None
|
1587 |
+
else None
|
1588 |
+
)
|
1589 |
+
(
|
1590 |
+
prompt_embeds,
|
1591 |
+
pooled_prompt_embeds,
|
1592 |
+
text_ids,
|
1593 |
+
_prompt_mask,
|
1594 |
+
) = self.encode_prompt(
|
1595 |
+
prompt=prompt,
|
1596 |
+
prompt_2=prompt_2,
|
1597 |
+
prompt_embeds=prompt_embeds,
|
1598 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1599 |
+
device=device,
|
1600 |
+
num_images_per_prompt=num_images_per_prompt,
|
1601 |
+
max_sequence_length=max_sequence_length,
|
1602 |
+
lora_scale=lora_scale,
|
1603 |
+
)
|
1604 |
+
if _prompt_mask is not None:
|
1605 |
+
prompt_mask = _prompt_mask
|
1606 |
+
|
1607 |
+
if negative_prompt_2 == "" and negative_prompt != "":
|
1608 |
+
negative_prompt_2 = negative_prompt
|
1609 |
+
|
1610 |
+
negative_text_ids = text_ids
|
1611 |
+
if self._guidance_scale_real > 1.0 and (
|
1612 |
+
negative_prompt_embeds is None or negative_pooled_prompt_embeds is None
|
1613 |
+
):
|
1614 |
+
(
|
1615 |
+
negative_prompt_embeds,
|
1616 |
+
negative_pooled_prompt_embeds,
|
1617 |
+
negative_text_ids,
|
1618 |
+
_neg_prompt_mask,
|
1619 |
+
) = self.encode_prompt(
|
1620 |
+
prompt=negative_prompt,
|
1621 |
+
prompt_2=negative_prompt_2,
|
1622 |
+
prompt_embeds=None,
|
1623 |
+
pooled_prompt_embeds=None,
|
1624 |
+
device=device,
|
1625 |
+
num_images_per_prompt=num_images_per_prompt,
|
1626 |
+
max_sequence_length=max_sequence_length,
|
1627 |
+
lora_scale=lora_scale,
|
1628 |
+
)
|
1629 |
+
|
1630 |
+
if _neg_prompt_mask is not None:
|
1631 |
+
negative_mask = _neg_prompt_mask
|
1632 |
+
|
1633 |
+
# 4. Prepare latent variables
|
1634 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
1635 |
+
latents, latent_image_ids = self.prepare_latents(
|
1636 |
+
batch_size * num_images_per_prompt,
|
1637 |
+
num_channels_latents,
|
1638 |
+
height,
|
1639 |
+
width,
|
1640 |
+
prompt_embeds.dtype,
|
1641 |
+
device,
|
1642 |
+
generator,
|
1643 |
+
latents,
|
1644 |
+
)
|
1645 |
+
|
1646 |
+
# 5. Prepare timesteps
|
1647 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
1648 |
+
image_seq_len = latents.shape[1]
|
1649 |
+
mu = calculate_shift(
|
1650 |
+
image_seq_len,
|
1651 |
+
self.scheduler.config.base_image_seq_len,
|
1652 |
+
self.scheduler.config.max_image_seq_len,
|
1653 |
+
self.scheduler.config.base_shift,
|
1654 |
+
self.scheduler.config.max_shift,
|
1655 |
+
)
|
1656 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
1657 |
+
self.scheduler,
|
1658 |
+
num_inference_steps,
|
1659 |
+
device,
|
1660 |
+
timesteps,
|
1661 |
+
sigmas,
|
1662 |
+
mu=mu,
|
1663 |
+
)
|
1664 |
+
num_warmup_steps = max(
|
1665 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
1666 |
+
)
|
1667 |
+
self._num_timesteps = len(timesteps)
|
1668 |
+
|
1669 |
+
latents = latents
|
1670 |
+
latent_image_ids = latent_image_ids
|
1671 |
+
timesteps = timesteps
|
1672 |
+
text_ids = text_ids.to(device=device)
|
1673 |
+
|
1674 |
+
# handle guidance
|
1675 |
+
if self.transformer.config.guidance_embeds:
|
1676 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
1677 |
+
guidance = guidance.expand(latents.shape[0])
|
1678 |
+
else:
|
1679 |
+
guidance = None
|
1680 |
+
|
1681 |
+
if use_prompt_mask and prompt_mask is not None and not zero_using_prompt_mask:
|
1682 |
+
print('Using masking')
|
1683 |
+
elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
|
1684 |
+
print('Using zeroed embeds')
|
1685 |
+
else:
|
1686 |
+
print('Not using masking')
|
1687 |
+
|
1688 |
+
if self._guidance_scale_real > 1.0:
|
1689 |
+
print('Using classifier free guidance', self._guidance_scale_real)
|
1690 |
+
|
1691 |
+
# 6. Denoising loop
|
1692 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1693 |
+
for i, t in enumerate(timesteps):
|
1694 |
+
if self.interrupt:
|
1695 |
+
continue
|
1696 |
+
|
1697 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1698 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
1699 |
+
|
1700 |
+
assert prompt_mask is not None
|
1701 |
+
|
1702 |
+
extra_transformer_args = {}
|
1703 |
+
if use_prompt_mask and prompt_mask is not None and not zero_using_prompt_mask:
|
1704 |
+
extra_transformer_args["attention_mask"] = prompt_mask
|
1705 |
+
elif use_prompt_mask and prompt_mask is not None and zero_using_prompt_mask:
|
1706 |
+
mask_tens = prompt_mask.unsqueeze(-1).to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
1707 |
+
prompt_embeds = prompt_embeds * mask_tens
|
1708 |
+
|
1709 |
+
noise_pred = self.transformer(
|
1710 |
+
hidden_states=latents,
|
1711 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
1712 |
+
timestep=timestep / 1000,
|
1713 |
+
guidance=guidance,
|
1714 |
+
pooled_projections=pooled_prompt_embeds,
|
1715 |
+
encoder_hidden_states=prompt_embeds,
|
1716 |
+
txt_ids=text_ids,
|
1717 |
+
img_ids=latent_image_ids.to(device=device),
|
1718 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1719 |
+
return_dict=False,
|
1720 |
+
**extra_transformer_args,
|
1721 |
+
)[0]
|
1722 |
+
|
1723 |
+
# TODO optionally use batch prediction to speed this up.
|
1724 |
+
if self._guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
|
1725 |
+
extra_transformer_args_neg = {}
|
1726 |
+
if negative_mask is not None:
|
1727 |
+
extra_transformer_args_neg["attention_mask"] = negative_mask
|
1728 |
+
extra_transformer_args_neg["attention_mask"] is not None
|
1729 |
+
|
1730 |
+
noise_pred_uncond = self.transformer(
|
1731 |
+
hidden_states=latents,
|
1732 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
1733 |
+
timestep=timestep / 1000,
|
1734 |
+
guidance=guidance,
|
1735 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
1736 |
+
encoder_hidden_states=negative_prompt_embeds,
|
1737 |
+
txt_ids=negative_text_ids,
|
1738 |
+
img_ids=latent_image_ids.to(device=device),
|
1739 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
1740 |
+
return_dict=False,
|
1741 |
+
**extra_transformer_args_neg,
|
1742 |
+
)[0]
|
1743 |
+
|
1744 |
+
noise_pred = noise_pred_uncond + self._guidance_scale_real * (
|
1745 |
+
noise_pred - noise_pred_uncond
|
1746 |
+
)
|
1747 |
+
progress_bar.set_postfix(
|
1748 |
+
{
|
1749 |
+
'ts': timestep.detach().item() / 1000,
|
1750 |
+
'cfg': self._guidance_scale_real,
|
1751 |
+
},
|
1752 |
+
)
|
1753 |
+
else:
|
1754 |
+
progress_bar.set_postfix(
|
1755 |
+
{
|
1756 |
+
'ts': timestep.detach().item() / 1000,
|
1757 |
+
'cfg': 'N/A',
|
1758 |
+
},
|
1759 |
+
)
|
1760 |
+
|
1761 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1762 |
+
latents_dtype = latents.dtype
|
1763 |
+
latents = self.scheduler.step(
|
1764 |
+
noise_pred, t, latents, return_dict=False
|
1765 |
+
)[0]
|
1766 |
+
|
1767 |
+
if latents.dtype != latents_dtype:
|
1768 |
+
if torch.backends.mps.is_available():
|
1769 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1770 |
+
latents = latents.to(latents_dtype)
|
1771 |
+
|
1772 |
+
if callback_on_step_end is not None:
|
1773 |
+
callback_kwargs = {}
|
1774 |
+
for k in callback_on_step_end_tensor_inputs:
|
1775 |
+
callback_kwargs[k] = locals()[k]
|
1776 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1777 |
+
|
1778 |
+
latents = callback_outputs.pop("latents", latents)
|
1779 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1780 |
+
|
1781 |
+
# call the callback, if provided
|
1782 |
+
if i == len(timesteps) - 1 or (
|
1783 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
1784 |
+
):
|
1785 |
+
progress_bar.update()
|
1786 |
+
|
1787 |
+
if XLA_AVAILABLE:
|
1788 |
+
xm.mark_step()
|
1789 |
+
|
1790 |
+
if output_type == "latent":
|
1791 |
+
image = latents
|
1792 |
+
|
1793 |
+
else:
|
1794 |
+
latents = self._unpack_latents(
|
1795 |
+
latents, height, width, self.vae_scale_factor
|
1796 |
+
)
|
1797 |
+
latents = (
|
1798 |
+
latents / self.vae.config.scaling_factor
|
1799 |
+
) + self.vae.config.shift_factor
|
1800 |
+
|
1801 |
+
image = self.vae.decode(
|
1802 |
+
latents,
|
1803 |
+
return_dict=False,
|
1804 |
+
)[0]
|
1805 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1806 |
+
|
1807 |
+
# Offload all models
|
1808 |
+
self.maybe_free_model_hooks()
|
1809 |
+
|
1810 |
+
if not return_dict:
|
1811 |
+
return (image,)
|
1812 |
+
|
1813 |
+
return FluxPipelineOutput(images=image)
|