|
from constants import LCM_DEFAULT_MODEL |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
AutoencoderTiny, |
|
UNet2DConditionModel, |
|
LCMScheduler, |
|
) |
|
import torch |
|
from backend.tiny_decoder import get_tiny_decoder_vae_model |
|
from typing import Any |
|
from diffusers import ( |
|
LCMScheduler, |
|
StableDiffusionImg2ImgPipeline, |
|
StableDiffusionXLImg2ImgPipeline, |
|
) |
|
|
|
|
|
def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype = torch.float32,
):
    """Assemble a pipeline from an LCM UNet and a separate base model.

    The ``UNet2DConditionModel`` weights are loaded from *lcm_model_id* and
    passed as the ``unet`` of the pipeline loaded from *base_model_id*. The
    pipeline's scheduler is then replaced by an ``LCMScheduler`` built from
    the original scheduler's config.

    Args:
        lcm_model_id: Model id/path the UNet is loaded from.
        base_model_id: Model id/path of the base ``DiffusionPipeline``.
        use_local_model: Forwarded as ``local_files_only`` to both loads.
        torch_data_type: dtype used for both the UNet and the pipeline.
            Defaults to ``torch.float32``, the same parameter name and
            default that ``load_taesd`` uses.

    Returns:
        The assembled ``DiffusionPipeline``.
    """
    unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )
    # Build the LCM scheduler from the base scheduler's config so the shared
    # config values carry over.
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    return pipeline
|
|
|
|
|
def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
):
    """Replace ``pipeline.vae`` with a tiny autoencoder (``AutoencoderTiny``).

    The model to load is chosen by ``get_tiny_decoder_vae_model`` from the
    pipeline's class name. The pipeline is modified in place and nothing is
    returned.

    Args:
        pipeline: The pipeline whose ``vae`` attribute is overwritten.
        use_local_model: Forwarded as ``local_files_only`` to the load.
        torch_data_type: dtype the tiny autoencoder is loaded with.
    """
    pipeline_name = pipeline.__class__.__name__
    tiny_vae_id = get_tiny_decoder_vae_model(pipeline_name)
    tiny_vae = AutoencoderTiny.from_pretrained(
        tiny_vae_id,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )
    pipeline.vae = tiny_vae
|
|
|
|
|
def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
):
    """Load the pipeline for *model_id*.

    Two LCM model ids are loaded as a UNet and combined with a fixed base
    pipeline via ``_get_lcm_pipeline_from_base_model``. Any other id is
    loaded directly with ``DiffusionPipeline.from_pretrained``.

    Args:
        model_id: Model id/path to load.
        use_local_model: Forwarded as ``local_files_only``.

    Returns:
        The loaded pipeline.
    """
    # LCM model id -> base pipeline its UNet is combined with.
    base_model_for = {
        "latent-consistency/lcm-sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
        "latent-consistency/lcm-ssd-1b": "segmind/SSD-1B",
    }
    base_model_id = base_model_for.get(model_id)
    if base_model_id is None:
        return DiffusionPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
        )
    return _get_lcm_pipeline_from_base_model(
        model_id,
        base_model_id,
        use_local_model,
    )
|
|
|
|
|
def get_image_to_image_pipeline(pipeline: Any) -> Any:
    """Build an image-to-image pipeline from *pipeline*'s components.

    The objects in ``pipeline.components`` are passed as keyword arguments
    to the img2img pipeline class that matches the source pipeline.

    Args:
        pipeline: A loaded pipeline whose class name is one of
            ``LatentConsistencyModelPipeline``, ``StableDiffusionPipeline``
            or ``StableDiffusionXLPipeline``.

    Returns:
        A ``StableDiffusionImg2ImgPipeline`` (for the first two) or a
        ``StableDiffusionXLImg2ImgPipeline`` (for the SDXL pipeline).

    Raises:
        ValueError: If the pipeline class is not supported. ``ValueError``
            subclasses ``Exception``, so ``except Exception`` handlers still
            catch it.
    """
    components = pipeline.components
    # Dispatch is by exact class name, so subclasses are not matched.
    pipeline_class = pipeline.__class__.__name__
    if pipeline_class in (
        "LatentConsistencyModelPipeline",
        "StableDiffusionPipeline",
    ):
        return StableDiffusionImg2ImgPipeline(**components)
    if pipeline_class == "StableDiffusionXLPipeline":
        return StableDiffusionXLImg2ImgPipeline(**components)
    raise ValueError(f"Unknown pipeline {pipeline_class}")
|
|