Spaces:
Build error
Build error
File size: 3,741 Bytes
ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 94913a9 ddc8a59 9435d99 ddc8a59 94913a9 ddc8a59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from transformers import CLIPTextModel
from conversion_utils import (populate_text_encoder, populate_unet,
run_assertion)
# Hugging Face Hub checkpoint providing the reference PyTorch (diffusers) weights.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

# Hub revisions used when loading components; None defers to the library's
# default revision.
REVISION = None
# Revision used for the UNet weights.
NON_EMA_REVISION = None

# Output resolution the KerasCV model is built for.
IMG_HEIGHT = 512
IMG_WIDTH = 512
def initialize_pt_models():
    """Initializes the separate models of Stable Diffusion from diffusers and downloads
    their pre-trained weights.

    Returns:
        A ``(text_encoder, vae, unet, safety_checker)`` tuple of PyTorch
        components, each loaded from its own subfolder of ``PRETRAINED_CKPT``.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    # NON_EMA_REVISION is reserved for the UNet weights; every other
    # component (including the safety checker below) uses REVISION.
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )
    return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
def initialize_tf_models():
    """Initializes the separate models of Stable Diffusion from KerasCV and downloads
    their pre-trained weights.

    Returns:
        A ``(sd_model, text_encoder, vae, unet)`` tuple: the full KerasCV
        ``StableDiffusion`` model followed by its ``text_encoder``,
        ``image_encoder`` and ``diffusion_model`` sub-models.
    """
    tf_sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )
    # Run a throwaway generation so the sub-models are instantiated and their
    # weights downloaded. The number of denoising steps does not change which
    # sub-models are invoked, so one step is enough; the image is discarded.
    _ = tf_sd_model.text_to_image("Cartoon", num_steps=1)
    tf_text_encoder = tf_sd_model.text_encoder
    # NOTE(review): ``image_encoder`` is presumably only the encoding half of
    # the VAE -- confirm before relying on ``tf_vae`` for decoding.
    tf_vae = tf_sd_model.image_encoder
    tf_unet = tf_sd_model.diffusion_model
    return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
def _load_fine_tuned_weights(model, weights_url, label):
    """Download a Keras weights file and load it into ``model``.

    Args:
        model: KerasCV sub-model whose weights are replaced in place.
        weights_url: Location of the weights file; handed to
            ``tf.keras.utils.get_file`` as the download ``origin``.
        label: Human-readable component name used in the progress message.
    """
    print(f"Loading fine-tuned {label} weights.")
    # The URL must be passed as ``origin``: given positionally it would be
    # taken as ``fname`` and get_file would have no source to download from.
    weights_path = tf.keras.utils.get_file(origin=weights_url)
    model.load_weights(weights_path)


def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
    """Port KerasCV Stable Diffusion weights into a diffusers pipeline.

    Args:
        text_encoder_weights: Optional URL of fine-tuned KerasCV text encoder
            weights. When None, the stock KerasCV weights are kept and the
            converted parameters are checked against the diffusers ones.
        unet_weights: Optional URL of fine-tuned KerasCV UNet weights, with
            the same semantics as ``text_encoder_weights``.

    Returns:
        A ``StableDiffusionPipeline`` whose text encoder and UNet carry the
        parameters converted from KerasCV.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    # Only the text encoder and UNet are converted; the full KerasCV model
    # and its image encoder are not needed here.
    _, tf_text_encoder, _, tf_unet = initialize_tf_models()
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        _load_fine_tuned_weights(tf_text_encoder, text_encoder_weights, "text encoder")
    if unet_weights is not None:
        _load_fine_tuned_weights(tf_unet, unet_weights, "UNet")

    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    print("Conversion done, now running optional assertions...")

    # Fine-tuned weights have no diffusers reference to compare against, so a
    # component is only checked when its stock weights were kept.
    if text_encoder_weights is None:
        run_assertion(pt_text_encoder.state_dict(), text_encoder_state_dict_from_tf)
    if unet_weights is None:
        run_assertion(pt_unet.state_dict(), unet_state_dict_from_tf)
    if text_encoder_weights is None or unet_weights is None:
        # NOTE(review): this presumes run_assertion raises on a mismatch --
        # confirm in conversion_utils.
        print(
            "Assertions successful, populating the converted parameters into the diffusers models..."
        )

    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    pt_unet.load_state_dict(unet_state_dict_from_tf)
    print("Parameters ported, preparing StableDiffusionPipeline...")

    # The remaining pipeline pieces are fetched from the checkpoint, so they
    # use the same REVISION as the separately loaded components.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline
|