Quantizing FLUX
Not a problem, just a general observation:
I don't see how this approach works any better or faster than
quantizing the transformer and text_encoder_2 directly after
loading the pipeline:
import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qint8, quantize

repo_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

# Quantize the transformer weights to int8, then freeze to materialize them
quantize(pipe.transformer, weights=qint8)
print("Running transformer freeze DEV")
freeze(pipe.transformer)

# Same for the second text encoder (T5)
quantize(pipe.text_encoder_2, weights=qint8)
print("Running text_encoder_2 freeze DEV")
freeze(pipe.text_encoder_2)
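As a quick sanity check of that quantize/freeze flow without downloading FLUX, here is a minimal sketch on a toy module (the small Sequential model is just a stand-in for pipe.transformer, and it assumes optimum-quanto and torch are installed):

import torch
from optimum.quanto import freeze, qint8, quantize

# Toy stand-in for a pipeline component
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)

quantize(model, weights=qint8)  # swap Linear layers for quantized versions
freeze(model)                   # materialize the int8 weights

# The Linear weights are now quantized to qint8, and the model still runs
assert model[0].weight.qtype == qint8
out = model(torch.randn(2, 16))
assert out.shape == (2, 4)

The same two calls, applied to pipe.transformer and pipe.text_encoder_2, are all that the snippet above does.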
What am I missing here?
Known issue: https://github.com/huggingface/optimum-quanto/issues/270
Disabling GEMM reduces the loading time from 250 seconds to 20 seconds.
You can disable GEMM if you want:
from optimum import quanto

# Monkeypatch: bypass the GEMM-packed QBitsTensor.create path and fall back
# to the plain constructor, which skips the slow weight packing at load time
quanto.tensor.qbits.QBitsTensor.create = lambda *args, **kwargs: quanto.tensor.qbits.QBitsTensor(*args, **kwargs)