Flux Latent Preview at Half-Size
The decoder produces a fast preview image directly from Flux latents; similar preview decoders already exist in the wild for the Flux Dev model.
The maximum supported resolution is between 768 and 1024 px.
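For orientation, here is a minimal sketch of the shape arithmetic, assuming the Flux VAE compresses by 8x per side and the preview decoder upsamples the latent grid by 4x (hence "half-size"); the exact factor depends on the decoder architecture:

# Hypothetical shape arithmetic for a half-size preview (assumed 4x upsampling).
height, width = 1024, 1024                          # requested image size
latent_h, latent_w = height // 8, width // 8        # Flux latent grid: 128 x 128, 16 channels
preview_h, preview_w = latent_h * 4, latent_w * 4   # preview output: 512 x 512
print(preview_h, preview_w)                         # half the full resolution per side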
Retraining the text encoder and the VAE decoder reduced the checkpoint size by around 10 GB, at the cost of output quality that feels roughly two years behind the current state of the art.
Inference
from diffusers import FluxPipeline
from safetensors.torch import load_model
from tea_model import TeaDecoder
import torch
from torchvision import transforms


def preview_image(latents, pipe):
    # Unpack the packed Flux latents into a (batch, 16, H/8, W/8) tensor.
    latents = FluxPipeline._unpack_latents(
        latents,
        pipe.default_sample_size * pipe.vae_scale_factor,
        pipe.default_sample_size * pipe.vae_scale_factor,
        pipe.vae_scale_factor)
    # Load the lightweight Tea decoder (16 input channels to match Flux latents).
    tea = TeaDecoder(ch_in=16)
    load_model(tea, './vae_decoder.safetensors')
    tea = tea.to(device='cuda')
    # Decode, then map the output from [-1, 1] to [0, 1].
    output = tea(latents.to(device='cuda', dtype=torch.float32)) / 2.0 + 0.5
    preview = transforms.ToPILImage()(output[0].clamp(0, 1).cpu())
    return preview


def full_size_image(latents, pipe):
    latents = FluxPipeline._unpack_latents(
        latents,
        pipe.default_sample_size * pipe.vae_scale_factor,
        pipe.default_sample_size * pipe.vae_scale_factor,
        pipe.vae_scale_factor)
    # Undo the scaling and shift that were applied when the latents were encoded.
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    latents = latents.to(device='cuda', dtype=pipe.vae.dtype)
    # Free cached memory before moving the full VAE onto the GPU.
    torch.cuda.empty_cache()
    pipe.vae = pipe.vae.to(device='cuda')
    pixel_values = pipe.vae.decode(latents, return_dict=False)[0]
    images = pipe.image_processor.postprocess(pixel_values.to('cpu'), output_type='pil')
    return images


if __name__ == '__main__':
    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev')
    latents = pipe('cat playing piano', num_inference_steps=10, output_type='latent').images
    # Decode both the full-size image and the half-size preview from the same latents.
    upscaled = full_size_image(latents, pipe)
    preview = preview_image(latents, pipe)
    upscaled[0].save('cat_full.png')  # filename arbitrary
    preview.save('cat.png')
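Since the Tea decoder is cheap, it also lends itself to live step-by-step previews. Below is a minimal sketch, not part of the original script, that wires preview_image into diffusers' callback_on_step_end hook; the per-step filenames are arbitrary, and for real use you would want to construct the TeaDecoder once up front rather than reloading it on every call:

def save_step_preview(pipeline, step, timestep, callback_kwargs):
    # Decode the in-progress latents with the cheap preview decoder.
    preview = preview_image(callback_kwargs['latents'], pipeline)
    preview.save(f'step_{step:03d}.png')  # arbitrary filename
    return callback_kwargs

images = pipe('cat playing piano',
              num_inference_steps=10,
              callback_on_step_end=save_step_preview,
              callback_on_step_end_tensor_inputs=['latents']).images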
Disclaimer
Use of this code, or of copies of this documentation, requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.