Spaces:
Build error
Build error
File size: 4,920 Bytes
3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from typing import Dict
import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion
# Length of the "text_model.embeddings.position_ids" buffer built by
# populate_text_encoder (ids 0..MAX_SEQ_LENGTH-1).
# NOTE(review): presumably the CLIP text context length, i.e. the row count of
# the Keras model's position embedding table — confirm they match.
MAX_SEQ_LENGTH = 77
def _layer_norm_state(prefix: str, layer_norm: tf.keras.layers.Layer) -> Dict[str, torch.Tensor]:
    """Export a Keras normalization layer as ``{prefix}.weight`` / ``{prefix}.bias``.

    Weight index 0 becomes ``weight`` and index 1 becomes ``bias``; both are
    copied as-is (no transpose).
    """
    weights = layer_norm.get_weights()  # fetch once instead of once per tensor
    return {
        f"{prefix}.weight": torch.from_numpy(weights[0]),
        f"{prefix}.bias": torch.from_numpy(weights[1]),
    }


def _dense_state(prefix: str, dense: tf.keras.layers.Layer) -> Dict[str, torch.Tensor]:
    """Export a Keras ``Dense`` layer as PyTorch ``nn.Linear``-style entries.

    Keras stores the kernel as ``(in_features, out_features)`` while
    ``nn.Linear.weight`` is ``(out_features, in_features)``, so the kernel is
    transposed. ``.contiguous()`` materializes that transpose: a transposed
    NumPy view would otherwise become a strided tensor, which some
    serializers (e.g. safetensors) refuse to save.
    """
    weights = dense.get_weights()  # fetch once instead of once per tensor
    return {
        f"{prefix}.weight": torch.from_numpy(weights[0].transpose()).contiguous(),
        f"{prefix}.bias": torch.from_numpy(weights[1]),
    }


def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Build a PyTorch state dict from the KerasCV Stable Diffusion text encoder.

    Walks the top-level ``layers`` of ``tf_text_encoder`` and exports:

    * the ``CLIPEmbedding`` layer: token and position embedding tables;
    * every ``CLIPEncoderLayer``, numbered in the order encountered: both
      layer norms, the q/k/v/out attention projections and the fc1/fc2 MLP;
    * a top-level ``LayerNormalization``, exported as the final layer norm.

    Layers of any other type are ignored. A ``text_model.embeddings.position_ids``
    buffer of shape ``(1, MAX_SEQ_LENGTH)`` holding ``0..MAX_SEQ_LENGTH-1`` is
    also added. Keys use the ``text_model.*`` naming (presumably that of the
    Hugging Face ``CLIPTextModel`` — confirm against the target model).

    Args:
        tf_text_encoder: The Keras text encoder model.

    Returns:
        Mapping from parameter name to ``torch.Tensor``.
    """
    text_state_dict: Dict[str, torch.Tensor] = {
        # int64 row vector [[0, 1, ..., MAX_SEQ_LENGTH - 1]].
        "text_model.embeddings.position_ids": torch.arange(MAX_SEQ_LENGTH).unsqueeze(0),
    }
    num_encoder_layers = 0  # counts only CLIPEncoderLayer instances

    for layer in tf_text_encoder.layers:
        if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
            # Embedding tables are copied as-is (no transpose).
            text_state_dict["text_model.embeddings.token_embedding.weight"] = torch.from_numpy(
                layer.token_embedding.get_weights()[0]
            )
            text_state_dict["text_model.embeddings.position_embedding.weight"] = torch.from_numpy(
                layer.position_embedding.get_weights()[0]
            )

        elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
            prefix = f"text_model.encoder.layers.{num_encoder_layers}"
            attn = layer.clip_attn
            text_state_dict.update(_layer_norm_state(f"{prefix}.layer_norm1", layer.layer_norm1))
            text_state_dict.update(_layer_norm_state(f"{prefix}.layer_norm2", layer.layer_norm2))
            for proj_name, proj in (
                ("q_proj", attn.q_proj),
                ("k_proj", attn.k_proj),
                ("v_proj", attn.v_proj),
                ("out_proj", attn.out_proj),
            ):
                text_state_dict.update(_dense_state(f"{prefix}.self_attn.{proj_name}", proj))
            text_state_dict.update(_dense_state(f"{prefix}.mlp.fc1", layer.fc1))
            text_state_dict.update(_dense_state(f"{prefix}.mlp.fc2", layer.fc2))
            num_encoder_layers += 1

        elif isinstance(layer, tf.keras.layers.LayerNormalization):
            # Top-level LayerNormalization; if several exist, the last one wins.
            text_state_dict.update(_layer_norm_state("text_model.final_layer_norm", layer))

    return text_state_dict
|