sayakpaul (HF staff) committed
Commit 3bd4a93
1 Parent(s): 89a6b3b

fix: textual inversion utility.

Files changed (1):
1. convert.py  +44 -16
convert.py CHANGED
@@ -35,26 +35,51 @@ def initialize_pt_models():
     return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
 
 
-def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
-    """Initializes the separate models of Stable Diffusion from KerasCV and downloads
-    their pre-trained weights."""
+def initialize_tf_models(
+    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
+):
+    """Initializes the separate models of Stable Diffusion from KerasCV and optionally
+    downloads their pre-trained weights."""
     tf_sd_model = keras_cv.models.StableDiffusion(
         img_height=IMG_HEIGHT, img_width=IMG_WIDTH
     )
+
     if text_encoder_weights is None:
         tf_text_encoder = tf_sd_model.text_encoder
     else:
         tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
             MAX_SEQ_LENGTH, download_weights=False
         )
-    tf_vae = tf_sd_model.image_encoder
+
     if unet_weights is None:
         tf_unet = tf_sd_model.diffusion_model
     else:
         tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
             IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
         )
-    return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
+
+    tf_tokenizer = tf_sd_model.tokenizer
+    if placeholder_token is not None:
+        tf_tokenizer.add_tokens(placeholder_token)
+
+    return tf_text_encoder, tf_unet, tf_tokenizer
+
+
+def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
+    """Initializes a fresh text encoder in case the weights are from Textual Inversion.
+
+    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/
+    """
+    new_vocab_size = len(tf_tokenizer.vocab)
+    new_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
+        MAX_SEQ_LENGTH,
+        vocab_size=new_vocab_size,
+        download_weights=False,
+    )
+
+    old_position_weights = tf_text_encoder.layers[2].position_embedding.get_weights()
+    new_text_encoder.layers[2].position_embedding.set_weights(old_position_weights)
+    return new_text_encoder
 
 
 def run_conversion(
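Note on this hunk: a Textual Inversion checkpoint carries a token-embedding table with one extra row for the learned concept, so it cannot be loaded into the stock `TextEncoder` built for the original CLIP vocabulary. `create_new_text_encoder` therefore rebuilds the encoder with the enlarged vocabulary and carries over only the position embeddings, which do not depend on vocabulary size; the remaining weights come from the fine-tuned checkpoint that `run_conversion` loads afterwards. A minimal sketch of that flow, assuming illustrative constants and a hypothetical placeholder token and weights URL:

```python
import keras_cv
import tensorflow as tf

IMG_HEIGHT = IMG_WIDTH = 512  # assumed to mirror the constants in convert.py
MAX_SEQ_LENGTH = 77

sd = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)

# Registering the placeholder token grows the tokenizer vocabulary by one entry.
sd.tokenizer.add_tokens("<placeholder-token>")  # hypothetical token

# Rebuild the text encoder with the enlarged vocabulary so the fine-tuned
# token-embedding table (one extra row) has a matching shape.
expanded_encoder = keras_cv.models.stable_diffusion.TextEncoder(
    MAX_SEQ_LENGTH,
    vocab_size=len(sd.tokenizer.vocab),
    download_weights=False,
)

# Position embeddings are independent of the vocabulary, so they are copied
# verbatim from the stock encoder (what create_new_text_encoder does).
expanded_encoder.layers[2].position_embedding.set_weights(
    sd.text_encoder.layers[2].position_embedding.get_weights()
)

# The fine-tuned Textual Inversion weights would then be loaded on top
# (hypothetical URL):
# weights_path = tf.keras.utils.get_file(origin="https://example.com/ti_text_encoder.h5")
# expanded_encoder.load_weights(weights_path)
```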
@@ -69,11 +94,23 @@ def run_conversion(
         pt_unet,
         pt_safety_checker,
     ) = initialize_pt_models()
-    tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models(
-        text_encoder_weights, unet_weights
+    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
+        text_encoder_weights, unet_weights, placeholder_token
     )
     print("Pre-trained model weights downloaded.")
 
+    if placeholder_token is not None:
+        print("Initializing a new text encoder with the placeholder token...")
+        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
+
+        print("Adding the placeholder token to PT CLIPTokenizer...")
+        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
+        if num_added_tokens == 0:
+            raise ValueError(
+                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                " `placeholder_token` that is not already in the tokenizer."
+            )
+
     if text_encoder_weights is not None:
         print("Loading fine-tuned text encoder weights.")
         text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
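On the PyTorch side, the placeholder token is now added to the `CLIPTokenizer` up front, and the conversion fails fast if the token already exists. For the ported weights to fit, the `transformers` text encoder's embedding matrix must also grow to the new vocabulary size; the standard pattern for that is `resize_token_embeddings`, sketched below with an assumed checkpoint id and a hypothetical token (this mirrors, but is not copied from, the rest of convert.py):

```python
from transformers import CLIPTextModel, CLIPTokenizer

# Assumption: PRETRAINED_CKPT in convert.py points at a diffusers-style repo
# such as "CompVis/stable-diffusion-v1-4".
ckpt = "CompVis/stable-diffusion-v1-4"
pt_tokenizer = CLIPTokenizer.from_pretrained(ckpt, subfolder="tokenizer")
pt_text_encoder = CLIPTextModel.from_pretrained(ckpt, subfolder="text_encoder")

placeholder_token = "<placeholder-token>"  # hypothetical
if pt_tokenizer.add_tokens(placeholder_token) == 0:
    raise ValueError(f"{placeholder_token} is already in the tokenizer.")

# Grow the embedding matrix to the new vocabulary size so the extra row from
# the ported KerasCV text encoder has a matching shape.
pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))
```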
@@ -90,15 +127,6 @@
     pt_unet.load_state_dict(unet_state_dict_from_tf)
     print("Populated PT UNet from TF weights.")
 
-    if placeholder_token is not None:
-        print("Adding the placeholder_token to CLIPTokenizer...")
-        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
-
     print("Weights ported, preparing StabelDiffusionPipeline...")
     pipeline = StableDiffusionPipeline.from_pretrained(
         PRETRAINED_CKPT,
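Once the KerasCV weights have been ported into the PyTorch modules, `run_conversion` hands them to `StableDiffusionPipeline.from_pretrained`, the call the hunk above ends in; `from_pretrained` accepts component overrides as keyword arguments. A rough, self-contained sketch of that assembly step, with "CompVis/stable-diffusion-v1-4" standing in for `PRETRAINED_CKPT` and freshly loaded components standing in for the ported ones:

```python
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer

ckpt = "CompVis/stable-diffusion-v1-4"  # assumption for PRETRAINED_CKPT

# In convert.py these modules would already carry the ported KerasCV weights
# (and the expanded vocabulary for Textual Inversion) before being passed in.
tokenizer = CLIPTokenizer.from_pretrained(ckpt, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(ckpt, subfolder="text_encoder")

pipeline = StableDiffusionPipeline.from_pretrained(
    ckpt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
)
pipeline.save_pretrained("converted-keras-cv-sd")  # hypothetical output directory
```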
 
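End to end, a pipeline converted this way is used like any other diffusers checkpoint, with the placeholder token written directly into the prompt. A short usage sketch with a hypothetical repo id and token:

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical repo id for a pipeline produced by convert.py.
pipeline = StableDiffusionPipeline.from_pretrained(
    "your-username/kerascv-sd-textual-inversion", torch_dtype=torch.float16
).to("cuda")

image = pipeline("a photo of <placeholder-token> riding a bicycle").images[0]
image.save("placeholder_token_sample.png")
```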