from typing import Dict

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion

# CLIP tokenizes every prompt to a fixed length of 77 tokens.
MAX_SEQ_LENGTH = 77


def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates a PyTorch state dict from the provided TensorFlow model
    (applicable only for the text encoder)."""
    text_state_dict = dict()
    num_encoder_layers = 0

    # Position ids: a fixed (1, MAX_SEQ_LENGTH) tensor of token positions.
    text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
        list(range(MAX_SEQ_LENGTH))
    ).unsqueeze(0)

    for layer in tf_text_encoder.layers:
        # Embeddings.
        if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
            text_state_dict[
                "text_model.embeddings.token_embedding.weight"
            ] = torch.from_numpy(layer.token_embedding.get_weights()[0])
            text_state_dict[
                "text_model.embeddings.position_embedding.weight"
            ] = torch.from_numpy(layer.position_embedding.get_weights()[0])

        # Encoder blocks.
        elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
            # LayerNorms. Keras stores LayerNormalization weights as
            # [gamma, beta], which map directly to PyTorch's weight and bias.
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.weight"
            ] = torch.from_numpy(layer.layer_norm1.get_weights()[0])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.bias"
            ] = torch.from_numpy(layer.layer_norm1.get_weights()[1])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.weight"
            ] = torch.from_numpy(layer.layer_norm2.get_weights()[0])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.bias"
            ] = torch.from_numpy(layer.layer_norm2.get_weights()[1])

            # Attention. TF Dense kernels have shape (in_features, out_features);
            # torch.nn.Linear weights are (out_features, in_features), hence the
            # transposes below.
            q_proj = layer.clip_attn.q_proj
            k_proj = layer.clip_attn.k_proj
            v_proj = layer.clip_attn.v_proj
            out_proj = layer.clip_attn.out_proj
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.weight"
            ] = torch.from_numpy(q_proj.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.bias"
            ] = torch.from_numpy(q_proj.get_weights()[1])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.weight"
            ] = torch.from_numpy(k_proj.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.bias"
            ] = torch.from_numpy(k_proj.get_weights()[1])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.weight"
            ] = torch.from_numpy(v_proj.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.bias"
            ] = torch.from_numpy(v_proj.get_weights()[1])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.weight"
            ] = torch.from_numpy(out_proj.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.bias"
            ] = torch.from_numpy(out_proj.get_weights()[1])

            # MLPs. Dense kernels are transposed for the same reason as above.
            fc1 = layer.fc1
            fc2 = layer.fc2
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.weight"
            ] = torch.from_numpy(fc1.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.bias"
            ] = torch.from_numpy(fc1.get_weights()[1])
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.weight"
            ] = torch.from_numpy(fc2.get_weights()[0].transpose())
            text_state_dict[
                f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.bias"
            ] = torch.from_numpy(fc2.get_weights()[1])
            num_encoder_layers += 1

        # Final LayerNorm.
        elif isinstance(layer, tf.keras.layers.LayerNormalization):
            text_state_dict["text_model.final_layer_norm.weight"] = torch.from_numpy(
                layer.get_weights()[0]
            )
            text_state_dict["text_model.final_layer_norm.bias"] = torch.from_numpy(
                layer.get_weights()[1]
            )

    return text_state_dict
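

# Usage sketch (an assumption, not part of the original converter): build the
# KerasCV text encoder, convert its weights, and load them into the matching
# `transformers` CLIPTextModel. The `TextEncoder(MAX_SEQ_LENGTH)` call and the
# "openai/clip-vit-large-patch14" checkpoint (the text encoder used by Stable
# Diffusion v1.x) are assumptions about the surrounding APIs.
if __name__ == "__main__":
    from transformers import CLIPTextModel

    tf_text_encoder = stable_diffusion.text_encoder.TextEncoder(MAX_SEQ_LENGTH)
    text_state_dict = populate_text_encoder(tf_text_encoder)

    pt_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    # strict=False: recent `transformers` releases drop `position_ids` from the
    # state dict, while older ones keep it as a registered buffer.
    missing, unexpected = pt_text_encoder.load_state_dict(text_state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)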