from constants import (
    TAESD_MODEL,
    TAESDXL_MODEL,
    TAESD_MODEL_OPENVINO,
    TAESDXL_MODEL_OPENVINO,
)
def get_tiny_decoder_vae_model(pipeline_class: str) -> str:
    """Return the tiny-decoder VAE model constant for a pipeline class name.

    The lookup is an exact, case-sensitive string match on the class name.

    Args:
        pipeline_class: Name of the pipeline class, e.g.
            ``"StableDiffusionXLPipeline"``.

    Returns:
        One of ``TAESD_MODEL``, ``TAESDXL_MODEL``, ``TAESD_MODEL_OPENVINO``
        or ``TAESDXL_MODEL_OPENVINO``, depending on the pipeline family.

    Raises:
        ValueError: If ``pipeline_class`` is not one of the recognised
            pipeline class names.
    """
    print(f"Pipeline class : {pipeline_class}")

    if pipeline_class in (
        "LatentConsistencyModelPipeline",
        "StableDiffusionPipeline",
        "StableDiffusionImg2ImgPipeline",
    ):
        return TAESD_MODEL

    if pipeline_class in (
        "StableDiffusionXLPipeline",
        "StableDiffusionXLImg2ImgPipeline",
    ):
        return TAESDXL_MODEL

    if pipeline_class in (
        "OVStableDiffusionPipeline",
        "OVStableDiffusionImg2ImgPipeline",
    ):
        return TAESD_MODEL_OPENVINO

    if pipeline_class == "OVStableDiffusionXLPipeline":
        return TAESDXL_MODEL_OPENVINO

    # ValueError subclasses Exception, so callers catching Exception still work.
    raise ValueError("No valid pipeline class found!")