Run it on low VRAM GPUs
#1
by
sean-mediabox
- opened
Adapted from PixArt, the following script allows running this model on low-VRAM GPUs:
# pip install -U accelerate transformers bitsandbytes
# pip install -U git+https://github.com/huggingface/diffusers
import gc

import torch
from diffusers import DiffusionPipeline
from transformers import BitsAndBytesConfig, T5EncoderModel
def flush() -> None:
    """Free unreferenced Python objects, then empty PyTorch's CUDA cache.

    The garbage collector runs first so that objects which are no longer
    referenced are collected before the CUDA cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()
def bytes_to_giga_bytes(bytes: float) -> float:
    """Convert a byte count to gigabytes using 1024-based units.

    Args:
        bytes: Number of bytes. The name shadows the builtin ``bytes`` but is
            kept so existing keyword callers continue to work.

    Returns:
        The size divided by 1024 three times (i.e. by ``1024 ** 3``).
    """
    return bytes / 1024 / 1024 / 1024
model_name = "ptx0/pixart-900m-1024-ft"

# --- Stage 1: encode the prompt with only the text encoder loaded ----------
# The transformer is deliberately left out (transformer=None) so that it and
# the text encoder are never loaded at the same time.
# Loading in 8 bits needs `bitsandbytes`.
text_encoder = T5EncoderModel.from_pretrained(
    model_name,
    subfolder="text_encoder",
    # Passing a config object replaces the deprecated bare `load_in_8bit=True`
    # keyword argument.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    text_encoder=text_encoder,
    transformer=None,
    device_map="balanced",
)

with torch.no_grad():
    prompt = "A landscape photograph of a small cottage in the middle of a field of wild flowers with mountains off in the distance at sunset"
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(prompt)

# Drop every reference to the text encoder before loading the next stage.
del text_encoder
del pipe
flush()

# --- Stage 2: denoise to latents without a text encoder --------------------
# The embeddings computed above are passed in directly, so the pipeline is
# reloaded with text_encoder=None.
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# The transformer is no longer needed once the latents exist; free it before
# running the VAE decoder.
del pipe.transformer
flush()

# --- Stage 3: decode the latents into an image ------------------------------
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")

# Persist the result; previously the decoded image was computed and discarded.
# `image` is indexed as a list here (one entry, since num_images_per_prompt=1).
image[0].save("output.png")

del pipe
flush()

print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)