Spaces:
Build error
Build error
fix: textual inversion utility.
Browse files- convert.py +44 -16
convert.py
CHANGED
@@ -35,26 +35,51 @@ def initialize_pt_models():
|
|
35 |
return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
|
36 |
|
37 |
|
38 |
-
def initialize_tf_models(
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
tf_sd_model = keras_cv.models.StableDiffusion(
|
42 |
img_height=IMG_HEIGHT, img_width=IMG_WIDTH
|
43 |
)
|
|
|
44 |
if text_encoder_weights is None:
|
45 |
tf_text_encoder = tf_sd_model.text_encoder
|
46 |
else:
|
47 |
tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
|
48 |
MAX_SEQ_LENGTH, download_weights=False
|
49 |
)
|
50 |
-
|
51 |
if unet_weights is None:
|
52 |
tf_unet = tf_sd_model.diffusion_model
|
53 |
else:
|
54 |
tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
|
55 |
IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
|
56 |
)
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
def run_conversion(
|
@@ -69,11 +94,23 @@ def run_conversion(
|
|
69 |
pt_unet,
|
70 |
pt_safety_checker,
|
71 |
) = initialize_pt_models()
|
72 |
-
|
73 |
-
text_encoder_weights, unet_weights
|
74 |
)
|
75 |
print("Pre-trained model weights downloaded.")
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
if text_encoder_weights is not None:
|
78 |
print("Loading fine-tuned text encoder weights.")
|
79 |
text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
|
@@ -90,15 +127,6 @@ def run_conversion(
|
|
90 |
pt_unet.load_state_dict(unet_state_dict_from_tf)
|
91 |
print("Populated PT UNet from TF weights.")
|
92 |
|
93 |
-
if placeholder_token is not None:
|
94 |
-
print("Adding the placeholder_token to CLIPTokenizer...")
|
95 |
-
num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
|
96 |
-
if num_added_tokens == 0:
|
97 |
-
raise ValueError(
|
98 |
-
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
99 |
-
" `placeholder_token` that is not already in the tokenizer."
|
100 |
-
)
|
101 |
-
|
102 |
print("Weights ported, preparing StabelDiffusionPipeline...")
|
103 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
104 |
PRETRAINED_CKPT,
|
|
|
35 |
return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
|
36 |
|
37 |
|
38 |
+
def initialize_tf_models(
|
39 |
+
text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
|
40 |
+
):
|
41 |
+
"""Initializes the separate models of Stable Diffusion from KerasCV and optionally
|
42 |
+
downloads their pre-trained weights."""
|
43 |
tf_sd_model = keras_cv.models.StableDiffusion(
|
44 |
img_height=IMG_HEIGHT, img_width=IMG_WIDTH
|
45 |
)
|
46 |
+
|
47 |
if text_encoder_weights is None:
|
48 |
tf_text_encoder = tf_sd_model.text_encoder
|
49 |
else:
|
50 |
tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
|
51 |
MAX_SEQ_LENGTH, download_weights=False
|
52 |
)
|
53 |
+
|
54 |
if unet_weights is None:
|
55 |
tf_unet = tf_sd_model.diffusion_model
|
56 |
else:
|
57 |
tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
|
58 |
IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
|
59 |
)
|
60 |
+
|
61 |
+
tf_tokenizer = tf_sd_model.tokenizer
|
62 |
+
if placeholder_token is not None:
|
63 |
+
tf_tokenizer.add_tokens(placeholder_token)
|
64 |
+
|
65 |
+
return tf_text_encoder, tf_unet, tf_tokenizer
|
66 |
+
|
67 |
+
|
68 |
+
def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
|
69 |
+
"""Initializes a fresh text encoder in case the weights are from Textual Inversion.
|
70 |
+
|
71 |
+
Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/
|
72 |
+
"""
|
73 |
+
new_vocab_size = len(tf_tokenizer.vocab)
|
74 |
+
new_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
|
75 |
+
MAX_SEQ_LENGTH,
|
76 |
+
vocab_size=new_vocab_size,
|
77 |
+
download_weights=False,
|
78 |
+
)
|
79 |
+
|
80 |
+
old_position_weights = tf_text_encoder.layers[2].position_embedding.get_weights()
|
81 |
+
new_text_encoder.layers[2].position_embedding.set_weights(old_position_weights)
|
82 |
+
return new_text_encoder
|
83 |
|
84 |
|
85 |
def run_conversion(
|
|
|
94 |
pt_unet,
|
95 |
pt_safety_checker,
|
96 |
) = initialize_pt_models()
|
97 |
+
tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
|
98 |
+
text_encoder_weights, unet_weights, placeholder_token
|
99 |
)
|
100 |
print("Pre-trained model weights downloaded.")
|
101 |
|
102 |
+
if placeholder_token is not None:
|
103 |
+
print("Initializing a new text encoder with the placeholder token...")
|
104 |
+
tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
|
105 |
+
|
106 |
+
print("Adding the placeholder token to PT CLIPTokenizer...")
|
107 |
+
num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
|
108 |
+
if num_added_tokens == 0:
|
109 |
+
raise ValueError(
|
110 |
+
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
111 |
+
" `placeholder_token` that is not already in the tokenizer."
|
112 |
+
)
|
113 |
+
|
114 |
if text_encoder_weights is not None:
|
115 |
print("Loading fine-tuned text encoder weights.")
|
116 |
text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
|
|
|
127 |
pt_unet.load_state_dict(unet_state_dict_from_tf)
|
128 |
print("Populated PT UNet from TF weights.")
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
print("Weights ported, preparing StabelDiffusionPipeline...")
|
131 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
132 |
PRETRAINED_CKPT,
|