# Residue from the Hugging Face file viewer (not part of the program):
# author: sayakpaul (HF staff) | commit 3304f7d "apply styling." | size 4.92 kB
from typing import Dict
import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion
# Number of position ids written to "text_model.embeddings.position_ids"
# (presumably the maximum token sequence length of the text encoder — confirm).
MAX_SEQ_LENGTH = 77
def _layer_norm_entries(
    prefix: str, layer_norm: tf.keras.layers.Layer
) -> Dict[str, torch.Tensor]:
    """Return the ``weight`` and ``bias`` state-dict entries for a LayerNorm layer.

    Args:
        prefix: State-dict key prefix, without the trailing ``.weight``/``.bias``.
        layer_norm: Layer whose first two weights are copied as-is.
    """
    # get_weights() returns every weight of the layer as NumPy arrays, so it is
    # called once and indexed rather than called once per entry.
    weights = layer_norm.get_weights()
    return {
        f"{prefix}.weight": torch.from_numpy(weights[0]),
        f"{prefix}.bias": torch.from_numpy(weights[1]),
    }


def _dense_entries(
    prefix: str, dense: tf.keras.layers.Layer
) -> Dict[str, torch.Tensor]:
    """Return the ``weight`` and ``bias`` state-dict entries for a projection layer.

    The first weight is transposed before conversion (presumably because the
    Keras kernel is laid out as (in, out) while the PyTorch linear weight is
    (out, in) — confirm against the target model). The second weight is copied
    unchanged.

    Args:
        prefix: State-dict key prefix, without the trailing ``.weight``/``.bias``.
        dense: Layer whose first two weights are converted.
    """
    weights = dense.get_weights()
    return {
        f"{prefix}.weight": torch.from_numpy(weights[0].transpose()),
        f"{prefix}.bias": torch.from_numpy(weights[1]),
    }


def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates the state dict from the provided TensorFlow model
    (applicable only for the text encoder).

    Walks ``tf_text_encoder.layers`` in order and converts the weights of the
    recognised layers into ``torch.Tensor`` objects stored under
    ``text_model.*`` keys:

    * ``CLIPEmbedding``: token and position embedding tables.
    * ``CLIPEncoderLayer``: both LayerNorms, the four attention projections
      and the two MLP projections; encoder blocks are numbered from 0 in the
      order they are encountered.
    * ``tf.keras.layers.LayerNormalization``: stored as the final LayerNorm.

    Layers of any other type are ignored.

    Args:
        tf_text_encoder: Keras model whose ``layers`` are scanned.

    Returns:
        Mapping from state-dict key to tensor. It always contains
        ``text_model.embeddings.position_ids`` with shape
        ``(1, MAX_SEQ_LENGTH)``.
    """
    # Position ids: 0 .. MAX_SEQ_LENGTH - 1 with a leading batch dimension.
    text_state_dict: Dict[str, torch.Tensor] = {
        "text_model.embeddings.position_ids": torch.arange(MAX_SEQ_LENGTH).unsqueeze(0)
    }
    num_encoder_layers = 0

    for layer in tf_text_encoder.layers:
        # Embeddings.
        if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
            text_state_dict[
                "text_model.embeddings.token_embedding.weight"
            ] = torch.from_numpy(layer.token_embedding.get_weights()[0])
            text_state_dict[
                "text_model.embeddings.position_embedding.weight"
            ] = torch.from_numpy(layer.position_embedding.get_weights()[0])

        # Encoder blocks.
        elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
            block_prefix = f"text_model.encoder.layers.{num_encoder_layers}"

            # LayerNorms.
            text_state_dict.update(
                _layer_norm_entries(f"{block_prefix}.layer_norm1", layer.layer_norm1)
            )
            text_state_dict.update(
                _layer_norm_entries(f"{block_prefix}.layer_norm2", layer.layer_norm2)
            )

            # Attention.
            attention = layer.clip_attn
            for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
                text_state_dict.update(
                    _dense_entries(
                        f"{block_prefix}.self_attn.{proj_name}",
                        getattr(attention, proj_name),
                    )
                )

            # MLPs.
            for fc_name in ("fc1", "fc2"):
                text_state_dict.update(
                    _dense_entries(
                        f"{block_prefix}.mlp.{fc_name}", getattr(layer, fc_name)
                    )
                )

            num_encoder_layers += 1

        # Final LayerNorm.
        elif isinstance(layer, tf.keras.layers.LayerNormalization):
            text_state_dict.update(
                _layer_norm_entries("text_model.final_layer_norm", layer)
            )

    return text_state_dict