Edit model card

black-forest-labs/FLUX.1-dev with the Transformer model quantized to INT4 and the T5 text encoder quantized to INT8 using Optimum Quanto, with computations performed in FP16.

Install the dependencies first (shell): `pip install diffusers optimum-quanto`
import json
import torch
import diffusers
import transformers
from optimum.quanto import requantize
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download


def load_quanto_transformer(repo_path, device="cuda"):
    """Build the FLUX transformer skeleton and load pre-quantized Quanto weights into it.

    Downloads ``transformer/quantization_map.json``, ``transformer/config.json`` and
    ``transformer/diffusion_pytorch_model.safetensors`` from ``repo_path`` on the
    Hugging Face Hub. It instantiates an empty ``FluxTransformer2DModel`` (FP16) on
    the meta device, then calls Optimum Quanto's ``requantize`` with the downloaded
    state dict and quantization map.

    Args:
        repo_path: Hugging Face Hub repo id holding the quantized checkpoint.
        device: Device passed to ``requantize`` for the loaded weights
            (default ``"cuda"``, matching the original behavior).

    Returns:
        The requantized ``FluxTransformer2DModel``.
    """
    with open(hf_hub_download(repo_path, "transformer/quantization_map.json"), "r", encoding="utf-8") as f:
        quantization_map = json.load(f)
    # Parse the config ourselves and pass a dict: passing a file path to
    # ``from_config`` is deprecated in diffusers.
    with open(hf_hub_download(repo_path, "transformer/config.json"), "r", encoding="utf-8") as f:
        transformer_config = json.load(f)
    # Constructing under the meta device creates parameters without storage;
    # the real (quantized) weights come from ``requantize`` below.
    with torch.device("meta"):
        transformer = diffusers.FluxTransformer2DModel.from_config(transformer_config).to(torch.float16)
    state_dict = load_file(hf_hub_download(repo_path, "transformer/diffusion_pytorch_model.safetensors"))
    requantize(transformer, state_dict, quantization_map, device=torch.device(device))
    return transformer


def load_quanto_text_encoder_2(repo_path, device="cuda"):
    """Build the T5 encoder skeleton and load pre-quantized Quanto weights into it.

    Downloads ``text_encoder_2/quantization_map.json``, ``text_encoder_2/config.json``
    and ``text_encoder_2/model.safetensors`` from ``repo_path`` on the Hugging Face Hub.
    It instantiates an empty ``T5EncoderModel`` (FP16) on the meta device, then calls
    Optimum Quanto's ``requantize`` with the downloaded state dict and quantization map.

    Args:
        repo_path: Hugging Face Hub repo id holding the quantized checkpoint.
        device: Device passed to ``requantize`` for the loaded weights
            (default ``"cuda"``, matching the original behavior).

    Returns:
        The requantized ``T5EncoderModel``.
    """
    with open(hf_hub_download(repo_path, "text_encoder_2/quantization_map.json"), "r", encoding="utf-8") as f:
        quantization_map = json.load(f)
    with open(hf_hub_download(repo_path, "text_encoder_2/config.json"), "r", encoding="utf-8") as f:
        t5_config = transformers.T5Config(**json.load(f))
    # Constructing under the meta device creates parameters without storage;
    # the real (quantized) weights come from ``requantize`` below.
    with torch.device("meta"):
        text_encoder_2 = transformers.T5EncoderModel(t5_config).to(torch.float16)
    state_dict = load_file(hf_hub_download(repo_path, "text_encoder_2/model.safetensors"))
    requantize(text_encoder_2, state_dict, quantization_map, device=torch.device(device))
    return text_encoder_2


# Hub repo holding the full pipeline plus the pre-quantized transformer / T5 weights.
REPO_ID = "Disty0/FLUX.1-dev-qint4_tf-qint8_te"

# Load every pipeline component except the two quantized ones, which are
# rebuilt from their Quanto checkpoints and attached afterwards.
pipe = diffusers.AutoPipelineForText2Image.from_pretrained(
    REPO_ID,
    transformer=None,
    text_encoder_2=None,
    torch_dtype=torch.float16,
)
pipe.transformer = load_quanto_transformer(REPO_ID)
pipe.text_encoder_2 = load_quanto_text_encoder_2(REPO_ID)
pipe = pipe.to("cuda", dtype=torch.float16)


prompt = "A cat holding a sign that says hello world"
# A seeded CPU generator keeps the sampled noise reproducible across runs.
seeded_generator = torch.Generator("cpu").manual_seed(0)
output = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=seeded_generator,
)
image = output.images[0]
image.save("flux-dev.png")
Downloads last month: 9
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.