import logging
from typing import Any, Optional

import torch
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    LCMScheduler,
    Transformer2DModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from models.RewardPixart import RewardPixartPipeline, freeze_params
from models.RewardStableDiffusion import RewardStableDiffusion
from models.RewardStableDiffusionXL import RewardStableDiffusionXL
from models.RewardFlux import RewardFluxPipeline


def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
    enable_sequential_cpu_offload: bool = False,
):
    logging.info(f"Loading model: {model_name}")
    if model_name == "sd-turbo":
        pipe = RewardStableDiffusion.from_pretrained(
            "stabilityai/sd-turbo",
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # pipe = pipe.to(device, dtype)
    elif model_name == "sdxl-turbo":
        # The fp16-fix VAE avoids numerical issues when decoding in half precision.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/sdxl-turbo",
            vae=vae,
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        # pipe = pipe.to(device, dtype)
    elif model_name == "pixart":
        pipe = RewardPixartPipeline.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            torch_dtype=dtype,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # Swap in the DMD-distilled transformer and its matching DDPM scheduler.
        pipe.transformer = Transformer2DModel.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="transformer",
            torch_dtype=dtype,
            cache_dir=cache_dir,
        )
        pipe.scheduler = DDPMScheduler.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="scheduler",
            cache_dir=cache_dir,
        )
        # Speed up the T5 text encoder.
        pipe.text_encoder.to_bettertransformer()
        pipe.transformer.eval()
        freeze_params(pipe.transformer.parameters())
        pipe.transformer.enable_gradient_checkpointing()
        # pipe = pipe.to(device)
    elif model_name == "hyper-sd":
        base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        repo_name = "ByteDance/Hyper-SD"
        ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
        # Build the UNet from the base config only (no weights yet; defaults to
        # CPU and float32).
        unet = UNet2DConditionModel.from_config(
            base_model_id, subfolder="unet", cache_dir=cache_dir
        )
        # Load the distilled Hyper-SD weights directly onto the GPU.
        unet.load_state_dict(
            load_file(
                hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
                device="cuda",
            )
        )
        # Initialize the pipeline around the distilled UNet in fp16.
        pipe = RewardStableDiffusionXL.from_pretrained(
            base_model_id,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",
            cache_dir=cache_dir,
            is_hyper=True,
            memsave=memsave,
        )
        # Use the LCM scheduler instead of DDIM so specific timestep counts are
        # supported as inputs.
        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config, cache_dir=cache_dir
        )
    elif model_name == "flux":
        pipe = RewardFluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        # pipe.to(device, dtype)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    # if enable_sequential_cpu_offload:
    #     pipe.enable_sequential_cpu_offload()
    return pipe
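
# Illustrative call (a sketch, not from the original source): get_model leaves
# the pipeline on whatever device it was loaded to, since the .to(device)
# calls above are commented out, so a caller would move it explicitly. The
# "./cache" path is an assumption.
#
#     pipe = get_model("sd-turbo", torch.float16, torch.device("cuda"), "./cache")
#     pipe = pipe.to("cuda")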


def get_multi_apply_fn(
    model_type: str,
    seed: int,
    pipe: Optional[Any] = None,
    cache_dir: Optional[str] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    generator = torch.Generator("cuda").manual_seed(seed)
    if model_type == "flux":
        # torch.no_grad(pipe.apply) uses no_grad as a function decorator
        # (supported in recent PyTorch), so no autograd graph is built.
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            num_inference_steps=4,
            generator=generator,
        )
    elif model_type == "sdxl":
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            vae=vae,
            use_safetensors=True,
            cache_dir=cache_dir,
        )
        pipe = pipe.to(device, dtype)
        pipe.enable_sequential_cpu_offload()
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            guidance_scale=5.0,
            num_inference_steps=50,
            generator=generator,
        )
    elif model_type == "sd2":
        sd2_base = "stabilityai/stable-diffusion-2-1-base"
        scheduler = EulerDiscreteScheduler.from_pretrained(
            sd2_base,
            subfolder="scheduler",
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusion.from_pretrained(
            sd2_base,
            torch_dtype=dtype,
            cache_dir=cache_dir,
            scheduler=scheduler,
        )
        pipe = pipe.to(device, dtype)
        pipe.enable_sequential_cpu_offload()
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")
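

# Minimal smoke-test sketch (not part of the original module). Assumes a CUDA
# device, half precision, and a writable "./cache" directory; the model and
# type names come from the branches above.
if __name__ == "__main__":
    device = torch.device("cuda")
    pipe = get_model(
        "sdxl-turbo",
        dtype=torch.float16,
        device=device,
        cache_dir="./cache",
    )
    # The "sdxl" branch loads its own pipeline, so no `pipe` argument is needed.
    apply_fn = get_multi_apply_fn(
        "sdxl",
        seed=0,
        cache_dir="./cache",
        device=device,
        dtype=torch.float16,
    )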