from itertools import product
from typing import Dict, Optional

import tensorflow as tf
import torch
from keras_cv.models import stable_diffusion
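
# Note on weight layouts: Keras Dense kernels are stored as (in_features, out_features)
# while torch.nn.Linear expects (out_features, in_features), hence the `.transpose()`
# calls below. Similarly, Keras Conv2D kernels are (H, W, in_channels, out_channels)
# while torch.nn.Conv2d expects (out_channels, in_channels, H, W), hence
# `.transpose(3, 2, 0, 1)`.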


def port_transformer_block(
    transformer_block: tf.keras.Model,
    up_down: str,
    block_id: Optional[int],
    attention_id: int,
) -> Dict[str, torch.Tensor]:
    """Ports a single Transformer block to the diffusers naming scheme."""
    transformer_dict = dict()
    if block_id is not None:
        prefix = f"{up_down}_blocks.{block_id}"
    else:
        prefix = "mid_block"

    # Norms.
    for i in range(1, 4):
        if i == 1:
            norm = transformer_block.norm1
        elif i == 2:
            norm = transformer_block.norm2
        elif i == 3:
            norm = transformer_block.norm3
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"
        ] = torch.from_numpy(norm.get_weights()[0])
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"
        ] = torch.from_numpy(norm.get_weights()[1])

    # Attentions.
    for i in range(1, 3):
        if i == 1:
            attn = transformer_block.attn1
        else:
            attn = transformer_block.attn2
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"
        ] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"
        ] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"
        ] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"
        ] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
        transformer_dict[
            f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"
        ] = torch.from_numpy(attn.out_proj.get_weights()[1])

    # Feed-forward (GEGLU projection at net.0, output dense at net.2).
    for i in range(0, 3, 2):
        if i == 0:
            layer = transformer_block.geglu.dense
            transformer_dict[
                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"
            ] = torch.from_numpy(layer.get_weights()[0].transpose())
            transformer_dict[
                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"
            ] = torch.from_numpy(layer.get_weights()[1])
        else:
            layer = transformer_block.dense
            transformer_dict[
                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"
            ] = torch.from_numpy(layer.get_weights()[0].transpose())
            transformer_dict[
                f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"
            ] = torch.from_numpy(layer.get_weights()[1])
    return transformer_dict
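
# For instance, port_transformer_block(block, "down", 0, 1) emits keys such as
# "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", matching
# the state-dict naming of diffusers' UNet2DConditionModel.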


def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
    """Populates a diffusers-style state dict from the provided TensorFlow
    model (applicable only to the UNet)."""
    unet_state_dict = dict()
    timestep_emb = 1
    padded_conv = 1
    up_block = 0
    up_res_blocks = list(product([0, 1, 2, 3], [0, 1, 2]))
    up_res_block_flag = 0
    up_spatial_transformer_blocks = list(product([1, 2, 3], [0, 1, 2]))
    up_spatial_transformer_flag = 0

    for layer in tf_unet.layers:
        # Timestep embedding.
        if isinstance(layer, tf.keras.layers.Dense):
            unet_state_dict[
                f"time_embedding.linear_{timestep_emb}.weight"
            ] = torch.from_numpy(layer.get_weights()[0].transpose())
            unet_state_dict[
                f"time_embedding.linear_{timestep_emb}.bias"
            ] = torch.from_numpy(layer.get_weights()[1])
            timestep_emb += 1
        # Padded convs (downsamplers).
        elif isinstance(
            layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D
        ):
            if padded_conv == 1:
                # Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
                unet_state_dict["conv_in.weight"] = torch.from_numpy(
                    layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict["conv_in.bias"] = torch.from_numpy(
                    layer.get_weights()[1]
                )
            elif padded_conv in [2, 3, 4]:
                unet_state_dict[
                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"
                ] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"
                ] = torch.from_numpy(layer.get_weights()[1])
            elif padded_conv == 5:
                unet_state_dict["conv_out.weight"] = torch.from_numpy(
                    layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict["conv_out.bias"] = torch.from_numpy(
                    layer.get_weights()[1]
                )
            padded_conv += 1
        # Upsamplers.
        elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
            conv = layer.conv
            unet_state_dict[
                f"up_blocks.{up_block}.upsamplers.0.conv.weight"
            ] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
            unet_state_dict[
                f"up_blocks.{up_block}.upsamplers.0.conv.bias"
            ] = torch.from_numpy(conv.get_weights()[1])
            up_block += 1
        # Output norms.
        elif isinstance(
            layer,
            stable_diffusion.__internal__.layers.group_normalization.GroupNormalization,
        ):
            unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(
                layer.get_weights()[0]
            )
            unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(
                layer.get_weights()[1]
            )
        # All ResBlocks.
        elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
            layer_name = layer.name
            parts = layer_name.split("_")
            # Down.
            if len(parts) == 2 or int(parts[-1]) < 8:
                entry_flow = layer.entry_flow
                embedding_flow = layer.embedding_flow
                exit_flow = layer.exit_flow
                down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                down_resnet_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
                # Conv blocks.
                first_conv_layer = entry_flow[-1]
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"
                ] = torch.from_numpy(
                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"
                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                second_conv_layer = exit_flow[-1]
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"
                ] = torch.from_numpy(
                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"
                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
                # Residual blocks.
                if hasattr(layer, "residual_projection"):
                    if isinstance(
                        layer.residual_projection,
                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
                    ):
                        residual = layer.residual_projection
                        unet_state_dict[
                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"
                        ] = torch.from_numpy(
                            residual.get_weights()[0].transpose(3, 2, 0, 1)
                        )
                        unet_state_dict[
                            f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"
                        ] = torch.from_numpy(residual.get_weights()[1])
                # Timestep embedding.
                embedding_proj = embedding_flow[-1]
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"
                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"
                ] = torch.from_numpy(embedding_proj.get_weights()[1])
                # Norms.
                first_group_norm = entry_flow[0]
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"
                ] = torch.from_numpy(first_group_norm.get_weights()[0])
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"
                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                second_group_norm = exit_flow[0]
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"
                ] = torch.from_numpy(second_group_norm.get_weights()[0])
                unet_state_dict[
                    f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"
                ] = torch.from_numpy(second_group_norm.get_weights()[1])
            # Middle.
            elif int(parts[-1]) == 8 or int(parts[-1]) == 9:
                entry_flow = layer.entry_flow
                embedding_flow = layer.embedding_flow
                exit_flow = layer.exit_flow
                mid_resnet_id = int(parts[-1]) % 2
                # Conv blocks.
                first_conv_layer = entry_flow[-1]
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.conv1.weight"
                ] = torch.from_numpy(
                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.conv1.bias"
                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                second_conv_layer = exit_flow[-1]
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.conv2.weight"
                ] = torch.from_numpy(
                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.conv2.bias"
                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
                # Residual blocks.
                if hasattr(layer, "residual_projection"):
                    if isinstance(
                        layer.residual_projection,
                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
                    ):
                        residual = layer.residual_projection
                        unet_state_dict[
                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"
                        ] = torch.from_numpy(
                            residual.get_weights()[0].transpose(3, 2, 0, 1)
                        )
                        unet_state_dict[
                            f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"
                        ] = torch.from_numpy(residual.get_weights()[1])
                # Timestep embedding.
                embedding_proj = embedding_flow[-1]
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"
                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"
                ] = torch.from_numpy(embedding_proj.get_weights()[1])
                # Norms.
                first_group_norm = entry_flow[0]
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.norm1.weight"
                ] = torch.from_numpy(first_group_norm.get_weights()[0])
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.norm1.bias"
                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                second_group_norm = exit_flow[0]
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.norm2.weight"
                ] = torch.from_numpy(second_group_norm.get_weights()[0])
                unet_state_dict[
                    f"mid_block.resnets.{mid_resnet_id}.norm2.bias"
                ] = torch.from_numpy(second_group_norm.get_weights()[1])
            # Up.
            elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
                entry_flow = layer.entry_flow
                embedding_flow = layer.embedding_flow
                exit_flow = layer.exit_flow
                up_res_block = up_res_blocks[up_res_block_flag]
                up_block_id = up_res_block[0]
                up_resnet_id = up_res_block[1]
                # Conv blocks.
                first_conv_layer = entry_flow[-1]
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"
                ] = torch.from_numpy(
                    first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"
                ] = torch.from_numpy(first_conv_layer.get_weights()[1])
                second_conv_layer = exit_flow[-1]
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"
                ] = torch.from_numpy(
                    second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1)
                )
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"
                ] = torch.from_numpy(second_conv_layer.get_weights()[1])
                # Residual blocks.
                if hasattr(layer, "residual_projection"):
                    if isinstance(
                        layer.residual_projection,
                        stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D,
                    ):
                        residual = layer.residual_projection
                        unet_state_dict[
                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"
                        ] = torch.from_numpy(
                            residual.get_weights()[0].transpose(3, 2, 0, 1)
                        )
                        unet_state_dict[
                            f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"
                        ] = torch.from_numpy(residual.get_weights()[1])
                # Timestep embedding.
                embedding_proj = embedding_flow[-1]
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"
                ] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"
                ] = torch.from_numpy(embedding_proj.get_weights()[1])
                # Norms.
                first_group_norm = entry_flow[0]
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"
                ] = torch.from_numpy(first_group_norm.get_weights()[0])
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"
                ] = torch.from_numpy(first_group_norm.get_weights()[1])
                second_group_norm = exit_flow[0]
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"
                ] = torch.from_numpy(second_group_norm.get_weights()[0])
                unet_state_dict[
                    f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"
                ] = torch.from_numpy(second_group_norm.get_weights()[1])
                up_res_block_flag += 1
        # All SpatialTransformer blocks.
        elif isinstance(layer, stable_diffusion.diffusion_model.SpatialTransformer):
            layer_name = layer.name
            parts = layer_name.split("_")
            # Down.
            if len(parts) == 2 or int(parts[-1]) < 6:
                down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
                down_attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
                # Convs.
                proj1 = layer.proj1
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"
                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"
                ] = torch.from_numpy(proj1.get_weights()[1])
                proj2 = layer.proj2
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"
                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"
                ] = torch.from_numpy(proj2.get_weights()[1])
                # Transformer blocks.
                transformer_block = layer.transformer_block
                unet_state_dict.update(
                    port_transformer_block(
                        transformer_block, "down", down_block_id, down_attention_id
                    )
                )
                # Norms.
                norm = layer.norm
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"
                ] = torch.from_numpy(norm.get_weights()[0])
                unet_state_dict[
                    f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"
                ] = torch.from_numpy(norm.get_weights()[1])
            # Middle.
            elif int(parts[-1]) == 6:
                mid_attention_id = int(parts[-1]) % 2
                # Convs.
                proj1 = layer.proj1
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.proj_in.weight"
                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.proj_in.bias"
                ] = torch.from_numpy(proj1.get_weights()[1])
                proj2 = layer.proj2
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.proj_out.weight"
                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.proj_out.bias"
                ] = torch.from_numpy(proj2.get_weights()[1])
                # Transformer blocks.
                transformer_block = layer.transformer_block
                unet_state_dict.update(
                    port_transformer_block(
                        transformer_block, "mid", None, mid_attention_id
                    )
                )
                # Norms.
                norm = layer.norm
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.norm.weight"
                ] = torch.from_numpy(norm.get_weights()[0])
                unet_state_dict[
                    f"mid_block.attentions.{mid_attention_id}.norm.bias"
                ] = torch.from_numpy(norm.get_weights()[1])
            # Up.
            elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(
                up_spatial_transformer_blocks
            ):
                up_spatial_transformer_block = up_spatial_transformer_blocks[
                    up_spatial_transformer_flag
                ]
                up_block_id = up_spatial_transformer_block[0]
                up_attention_id = up_spatial_transformer_block[1]
                # Convs.
                proj1 = layer.proj1
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"
                ] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"
                ] = torch.from_numpy(proj1.get_weights()[1])
                proj2 = layer.proj2
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"
                ] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"
                ] = torch.from_numpy(proj2.get_weights()[1])
                # Transformer blocks.
                transformer_block = layer.transformer_block
                unet_state_dict.update(
                    port_transformer_block(
                        transformer_block, "up", up_block_id, up_attention_id
                    )
                )
                # Norms.
                norm = layer.norm
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"
                ] = torch.from_numpy(norm.get_weights()[0])
                unet_state_dict[
                    f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"
                ] = torch.from_numpy(norm.get_weights()[1])
                up_spatial_transformer_flag += 1
    return unet_state_dict
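

# Minimal usage sketch (not part of the original script): build the KerasCV UNet,
# port its weights, and load them into diffusers' UNet2DConditionModel. The
# checkpoint name ("runwayml/stable-diffusion-v1-5") and the DiffusionModel
# constructor arguments are assumptions that may differ across library versions.
if __name__ == "__main__":
    from diffusers import UNet2DConditionModel

    # Instantiating the KerasCV diffusion model downloads the SD v1.x weights.
    tf_unet = stable_diffusion.diffusion_model.DiffusionModel(
        img_height=512, img_width=512, max_text_length=77
    )
    state_dict = populate_unet(tf_unet)

    pt_unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    # strict=False tolerates any keys this port does not cover.
    missing, unexpected = pt_unet.load_state_dict(state_dict, strict=False)
    print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")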