Fix model loading
Browse files- quant_sdxl/quant_sdxl.py +4 -1
quant_sdxl/quant_sdxl.py
CHANGED
@@ -102,7 +102,10 @@ def main(args):
|
|
102 |
|
103 |
# Load model from float checkpoint
|
104 |
print(f"Loading model from {args.model}...")
|
105 |
-
|
|
|
|
|
|
|
106 |
print(f"Model loaded from {args.model}.")
|
107 |
|
108 |
# Move model to target device
|
|
|
102 |
|
103 |
# Load model from float checkpoint
|
104 |
print(f"Loading model from {args.model}...")
|
105 |
+
variant = 'fp16' if dtype == torch.float16 else None
|
106 |
+
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
|
107 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
108 |
+
pipe.vae.config.force_upcast=True
|
109 |
print(f"Model loaded from {args.model}.")
|
110 |
|
111 |
# Move model to target device
|