Spaces:
Build error
Build error
add: files.
Browse files- README.md +3 -3
- app.py +44 -0
- conversion_utils/__init__.py +3 -0
- conversion_utils/text_encoder.py +110 -0
- conversion_utils/unet.py +291 -0
- conversion_utils/utils.py +15 -0
- convert.py +90 -0
- hub_utils/__init__.py +2 -0
- hub_utils/readme.py +29 -0
- hub_utils/repo.py +15 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
---
|
2 |
-
title: Convert Kerascv
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.16.2
|
|
|
1 |
---
|
2 |
+
title: Convert Kerascv SD to Diffusers
|
3 |
+
emoji: 🧨
|
4 |
+
colorFrom: red
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.16.2
|
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from convert import run_conversion
|
3 |
+
from hub_utils import save_model_card, push_to_hub
|
4 |
+
|
5 |
+
|
6 |
+
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
|
7 |
+
DESCRIPTION = """
|
8 |
+
This Space lets you convert KerasCV Stable Diffusion weights to a format compatible with [Diffusers](https://github.com/huggingface/diffusers) 🧨. This allows users to fine-tune using KerasCV and use the fine-tuned weights in Diffusers taking advantage of its nifty features (like schedulers, fast attention, etc.). Specifically, the parameters are converted and then they are wrapped into a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview). This pipeline is then pushed to the Hugging Face Hub given you have provided a `your_hf_token`.
|
9 |
+
|
10 |
+
## Notes (important)
|
11 |
+
|
12 |
+
* Only Stable Diffusion (v1) is supported as of now. In particular this checkpoint: [`"CompVis/stable-diffusion-v1-4"`](https://huggingface.co/CompVis/stable-diffusion-v1-4).
|
13 |
+
* Only the text encoder and the UNet parameters converted since only these two elements are generally fine-tuned.
|
14 |
+
* [This Colab Notebook](https://colab.research.google.com/drive/1RYY077IQbAJldg8FkK8HSEpNILKHEwLb?usp=sharing) was used to develop the conversion utilities initially.
|
15 |
+
* You can choose not to provide `text_encoder_weights` and `unet_weights` in case you don't have any fine-tuned weights. In that case, the original parameters of the respective models (text encoder and UNet) from KerasCV will be used.
|
16 |
+
* You can provide only `text_encoder_weights` or `unet_weights` or both.
|
17 |
+
* When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
|
18 |
+
* If you don't provide `your_hf_token` the converted pipeline won't be pushed.
|
19 |
+
|
20 |
+
Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example on how you can change the scheduler of an already initialized pipeline.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
|
24 |
+
if text_encoder_weights == "":
|
25 |
+
text_encoder_weights = None
|
26 |
+
if unet_weights == "":
|
27 |
+
unet_weights = None
|
28 |
+
pipeline = run_conversion(text_encoder_weights, unet_weights)
|
29 |
+
output_path = "kerascv_sd_diffusers_pipeline"
|
30 |
+
pipeline.save_pretrained(output_path)
|
31 |
+
save_model_card(base_model=PRETRAINED_CKPT, repo_folder=output_path, weight_paths=[text_encoder_weights, unet_weights], repo_prefix=repo_prefix)
|
32 |
+
push_str = push_to_hub(hf_token, output_path, repo_prefix)
|
33 |
+
return push_str
|
34 |
+
|
35 |
+
demo = gr.Interface(
|
36 |
+
title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
|
37 |
+
description=DESCRIPTION,
|
38 |
+
allow_flagging="never",
|
39 |
+
inputs=[gr.Text(max_lines=1, label="your_hf_token"), gr.Text(max_lines=1, label="text_encoder_weights"), gr.Text(max_lines=1, label="unet_weights"), gr.Text(max_lines=1, label="output_repo_prefix")],
|
40 |
+
outputs=[gr.Markdown(label="output")],
|
41 |
+
fn=run,
|
42 |
+
)
|
43 |
+
|
44 |
+
demo.launch()
|
conversion_utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .text_encoder import populate_text_encoder
|
2 |
+
from .unet import populate_unet
|
3 |
+
from .utils import run_assertion
|
conversion_utils/text_encoder.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from keras_cv.models import stable_diffusion
|
2 |
+
import tensorflow as tf
|
3 |
+
import torch
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
MAX_SEQ_LENGTH = 77
|
7 |
+
|
8 |
+
def populate_text_encoder(tf_text_encoder: tf.keras.Model) -> Dict[str, torch.Tensor]:
|
9 |
+
"""Populates the state dict from the provided TensorFlow model
|
10 |
+
(applicable only for the text encoder)."""
|
11 |
+
text_state_dict = dict()
|
12 |
+
num_encoder_layers = 0
|
13 |
+
|
14 |
+
for layer in tf_text_encoder.layers:
|
15 |
+
# Embeddings.
|
16 |
+
if isinstance(layer, stable_diffusion.text_encoder.CLIPEmbedding):
|
17 |
+
text_state_dict[
|
18 |
+
"text_model.embeddings.token_embedding.weight"
|
19 |
+
] = torch.from_numpy(layer.token_embedding.get_weights()[0])
|
20 |
+
text_state_dict[
|
21 |
+
"text_model.embeddings.position_embedding.weight"
|
22 |
+
] = torch.from_numpy(layer.position_embedding.get_weights()[0])
|
23 |
+
|
24 |
+
# Encoder blocks.
|
25 |
+
elif isinstance(layer, stable_diffusion.text_encoder.CLIPEncoderLayer):
|
26 |
+
# LayerNorms
|
27 |
+
for i in range(1, 3):
|
28 |
+
if i == 1:
|
29 |
+
text_state_dict[
|
30 |
+
f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.weight"
|
31 |
+
] = torch.from_numpy(layer.layer_norm1.get_weights()[0])
|
32 |
+
text_state_dict[
|
33 |
+
f"text_model.encoder.layers.{num_encoder_layers}.layer_norm1.bias"
|
34 |
+
] = torch.from_numpy(layer.layer_norm1.get_weights()[1])
|
35 |
+
else:
|
36 |
+
text_state_dict[
|
37 |
+
f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.weight"
|
38 |
+
] = torch.from_numpy(layer.layer_norm2.get_weights()[0])
|
39 |
+
text_state_dict[
|
40 |
+
f"text_model.encoder.layers.{num_encoder_layers}.layer_norm2.bias"
|
41 |
+
] = torch.from_numpy(layer.layer_norm2.get_weights()[1])
|
42 |
+
|
43 |
+
# Attention.
|
44 |
+
q_proj = layer.clip_attn.q_proj
|
45 |
+
k_proj = layer.clip_attn.k_proj
|
46 |
+
v_proj = layer.clip_attn.v_proj
|
47 |
+
out_proj = layer.clip_attn.out_proj
|
48 |
+
|
49 |
+
text_state_dict[
|
50 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.weight"
|
51 |
+
] = torch.from_numpy(q_proj.get_weights()[0].transpose())
|
52 |
+
text_state_dict[
|
53 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.q_proj.bias"
|
54 |
+
] = torch.from_numpy(q_proj.get_weights()[1])
|
55 |
+
|
56 |
+
text_state_dict[
|
57 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.weight"
|
58 |
+
] = torch.from_numpy(k_proj.get_weights()[0].transpose())
|
59 |
+
text_state_dict[
|
60 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.k_proj.bias"
|
61 |
+
] = torch.from_numpy(k_proj.get_weights()[1])
|
62 |
+
|
63 |
+
text_state_dict[
|
64 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.weight"
|
65 |
+
] = torch.from_numpy(v_proj.get_weights()[0].transpose())
|
66 |
+
text_state_dict[
|
67 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.v_proj.bias"
|
68 |
+
] = torch.from_numpy(v_proj.get_weights()[1])
|
69 |
+
|
70 |
+
text_state_dict[
|
71 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.weight"
|
72 |
+
] = torch.from_numpy(out_proj.get_weights()[0].transpose())
|
73 |
+
text_state_dict[
|
74 |
+
f"text_model.encoder.layers.{num_encoder_layers}.self_attn.out_proj.bias"
|
75 |
+
] = torch.from_numpy(out_proj.get_weights()[1])
|
76 |
+
|
77 |
+
# MLPs.
|
78 |
+
fc1 = layer.fc1
|
79 |
+
fc2 = layer.fc2
|
80 |
+
|
81 |
+
text_state_dict[
|
82 |
+
f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.weight"
|
83 |
+
] = torch.from_numpy(fc1.get_weights()[0].transpose())
|
84 |
+
text_state_dict[
|
85 |
+
f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc1.bias"
|
86 |
+
] = torch.from_numpy(fc1.get_weights()[1])
|
87 |
+
text_state_dict[
|
88 |
+
f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.weight"
|
89 |
+
] = torch.from_numpy(fc2.get_weights()[0].transpose())
|
90 |
+
text_state_dict[
|
91 |
+
f"text_model.encoder.layers.{num_encoder_layers}.mlp.fc2.bias"
|
92 |
+
] = torch.from_numpy(fc2.get_weights()[1])
|
93 |
+
|
94 |
+
num_encoder_layers += 1
|
95 |
+
|
96 |
+
# Final LayerNorm.
|
97 |
+
elif isinstance(layer, tf.keras.layers.LayerNormalization):
|
98 |
+
text_state_dict["text_model.final_layer_norm.weight"] = torch.from_numpy(
|
99 |
+
layer.get_weights()[0]
|
100 |
+
)
|
101 |
+
text_state_dict["text_model.final_layer_norm.bias"] = torch.from_numpy(
|
102 |
+
layer.get_weights()[1]
|
103 |
+
)
|
104 |
+
|
105 |
+
# Position ids.
|
106 |
+
text_state_dict["text_model.embeddings.position_ids"] = torch.tensor(
|
107 |
+
list(range(77))
|
108 |
+
).unsqueeze(0)
|
109 |
+
|
110 |
+
return text_state_dict
|
conversion_utils/unet.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import torch
|
3 |
+
from typing import Dict
|
4 |
+
from itertools import product
|
5 |
+
from keras_cv.models import stable_diffusion
|
6 |
+
|
7 |
+
def port_transformer_block(transformer_block: tf.keras.Model, up_down: int, block_id: int, attention_id: int) -> Dict[str, torch.Tensor]:
|
8 |
+
"""Populates a Transformer block."""
|
9 |
+
transformer_dict = dict()
|
10 |
+
if block_id is not None:
|
11 |
+
prefix = f"{up_down}_blocks.{block_id}"
|
12 |
+
else:
|
13 |
+
prefix = "mid_block"
|
14 |
+
|
15 |
+
# Norms.
|
16 |
+
for i in range(1, 4):
|
17 |
+
if i == 1:
|
18 |
+
norm = transformer_block.norm1
|
19 |
+
elif i == 2:
|
20 |
+
norm = transformer_block.norm2
|
21 |
+
elif i == 3:
|
22 |
+
norm = transformer_block.norm3
|
23 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.weight"] = torch.from_numpy(norm.get_weights()[0])
|
24 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.norm{i}.bias"] = torch.from_numpy(norm.get_weights()[1])
|
25 |
+
|
26 |
+
# Attentions.
|
27 |
+
for i in range(1, 3):
|
28 |
+
if i == 1:
|
29 |
+
attn = transformer_block.attn1
|
30 |
+
else:
|
31 |
+
attn = transformer_block.attn2
|
32 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_q.weight"] = torch.from_numpy(attn.to_q.get_weights()[0].transpose())
|
33 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_k.weight"] = torch.from_numpy(attn.to_k.get_weights()[0].transpose())
|
34 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_v.weight"] = torch.from_numpy(attn.to_v.get_weights()[0].transpose())
|
35 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.weight"] = torch.from_numpy(attn.out_proj.get_weights()[0].transpose())
|
36 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.attn{i}.to_out.0.bias"] = torch.from_numpy(attn.out_proj.get_weights()[1])
|
37 |
+
|
38 |
+
# Dense.
|
39 |
+
for i in range(0, 3, 2):
|
40 |
+
if i == 0:
|
41 |
+
layer = transformer_block.geglu.dense
|
42 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
|
43 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.proj.bias"] = torch.from_numpy(layer.get_weights()[1])
|
44 |
+
else:
|
45 |
+
layer = transformer_block.dense
|
46 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
|
47 |
+
transformer_dict[f"{prefix}.attentions.{attention_id}.transformer_blocks.0.ff.net.{i}.bias"] = torch.from_numpy(layer.get_weights()[1])
|
48 |
+
|
49 |
+
return transformer_dict
|
50 |
+
|
51 |
+
|
52 |
+
def populate_unet(tf_unet: tf.keras.Model) -> Dict[str, torch.Tensor]:
|
53 |
+
"""Populates the state dict from the provided TensorFlow model
|
54 |
+
(applicable only for the UNet)."""
|
55 |
+
unet_state_dict = dict()
|
56 |
+
|
57 |
+
timstep_emb = 1
|
58 |
+
padded_conv = 1
|
59 |
+
up_block = 0
|
60 |
+
|
61 |
+
up_res_blocks = list(product([0, 1, 2, 3], [0, 1, 2]))
|
62 |
+
up_res_block_flag = 0
|
63 |
+
|
64 |
+
up_spatial_transformer_blocks = list(product([1, 2, 3], [0, 1, 2]))
|
65 |
+
up_spatial_transformer_flag = 0
|
66 |
+
|
67 |
+
for layer in tf_unet.layers:
|
68 |
+
# Timstep embedding.
|
69 |
+
if isinstance(layer, tf.keras.layers.Dense):
|
70 |
+
unet_state_dict[f"time_embedding.linear_{timstep_emb}.weight"] = torch.from_numpy(layer.get_weights()[0].transpose())
|
71 |
+
unet_state_dict[f"time_embedding.linear_{timstep_emb}.bias"] = torch.from_numpy(layer.get_weights()[1])
|
72 |
+
timstep_emb += 1
|
73 |
+
|
74 |
+
# Padded convs (downsamplers).
|
75 |
+
elif isinstance(layer, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
|
76 |
+
if padded_conv == 1:
|
77 |
+
# Transposition axes taken from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_pytorch_utils.py#L104
|
78 |
+
unet_state_dict["conv_in.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
|
79 |
+
unet_state_dict["conv_in.bias"] = torch.from_numpy(layer.get_weights()[1])
|
80 |
+
elif padded_conv in [2, 3, 4]:
|
81 |
+
unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
|
82 |
+
unet_state_dict[f"down_blocks.{padded_conv-2}.downsamplers.0.conv.bias"] = torch.from_numpy(layer.get_weights()[1])
|
83 |
+
elif padded_conv == 5:
|
84 |
+
unet_state_dict["conv_out.weight"] = torch.from_numpy(layer.get_weights()[0].transpose(3, 2, 0, 1))
|
85 |
+
unet_state_dict["conv_out.bias"] = torch.from_numpy(layer.get_weights()[1])
|
86 |
+
|
87 |
+
padded_conv += 1
|
88 |
+
|
89 |
+
# Upsamplers.
|
90 |
+
elif isinstance(layer, stable_diffusion.diffusion_model.Upsample):
|
91 |
+
conv = layer.conv
|
92 |
+
unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.weight"] = torch.from_numpy(conv.get_weights()[0].transpose(3, 2, 0, 1))
|
93 |
+
unet_state_dict[f"up_blocks.{up_block}.upsamplers.0.conv.bias"] = torch.from_numpy(conv.get_weights()[1])
|
94 |
+
up_block += 1
|
95 |
+
|
96 |
+
# Output norms.
|
97 |
+
elif isinstance(layer, stable_diffusion.__internal__.layers.group_normalization.GroupNormalization):
|
98 |
+
unet_state_dict["conv_norm_out.weight"] = torch.from_numpy(layer.get_weights()[0])
|
99 |
+
unet_state_dict["conv_norm_out.bias"] = torch.from_numpy(layer.get_weights()[1])
|
100 |
+
|
101 |
+
# All ResBlocks.
|
102 |
+
elif isinstance(layer, stable_diffusion.diffusion_model.ResBlock):
|
103 |
+
layer_name = layer.name
|
104 |
+
parts = layer_name.split("_")
|
105 |
+
|
106 |
+
# Down.
|
107 |
+
if len(parts) == 2 or int(parts[-1]) < 8:
|
108 |
+
entry_flow = layer.entry_flow
|
109 |
+
embedding_flow = layer.embedding_flow
|
110 |
+
exit_flow = layer.exit_flow
|
111 |
+
|
112 |
+
down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
|
113 |
+
down_resnet_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
|
114 |
+
|
115 |
+
# Conv blocks.
|
116 |
+
first_conv_layer = entry_flow[-1]
|
117 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
118 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
|
119 |
+
second_conv_layer = exit_flow[-1]
|
120 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
121 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
|
122 |
+
|
123 |
+
# Residual blocks.
|
124 |
+
if hasattr(layer, "residual_projection"):
|
125 |
+
if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
|
126 |
+
residual = layer.residual_projection
|
127 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
|
128 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
|
129 |
+
|
130 |
+
# Timestep embedding.
|
131 |
+
embedding_proj = embedding_flow[-1]
|
132 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
|
133 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
|
134 |
+
|
135 |
+
# Norms.
|
136 |
+
first_group_norm = entry_flow[0]
|
137 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
|
138 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
|
139 |
+
second_group_norm = exit_flow[0]
|
140 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
|
141 |
+
unet_state_dict[f"down_blocks.{down_block_id}.resnets.{down_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
|
142 |
+
|
143 |
+
# Middle.
|
144 |
+
elif int(parts[-1]) == 8 or int(parts[-1]) == 9:
|
145 |
+
entry_flow = layer.entry_flow
|
146 |
+
embedding_flow = layer.embedding_flow
|
147 |
+
exit_flow = layer.exit_flow
|
148 |
+
|
149 |
+
mid_resnet_id = int(parts[-1]) % 2
|
150 |
+
|
151 |
+
# Conv blocks.
|
152 |
+
first_conv_layer = entry_flow[-1]
|
153 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
154 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
|
155 |
+
second_conv_layer = exit_flow[-1]
|
156 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
157 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
|
158 |
+
|
159 |
+
# Residual blocks.
|
160 |
+
if hasattr(layer, "residual_projection"):
|
161 |
+
if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
|
162 |
+
residual = layer.residual_projection
|
163 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
|
164 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
|
165 |
+
|
166 |
+
# Timestep embedding.
|
167 |
+
embedding_proj = embedding_flow[-1]
|
168 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
|
169 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
|
170 |
+
|
171 |
+
# Norms.
|
172 |
+
first_group_norm = entry_flow[0]
|
173 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
|
174 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
|
175 |
+
second_group_norm = exit_flow[0]
|
176 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
|
177 |
+
unet_state_dict[f"mid_block.resnets.{mid_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
|
178 |
+
|
179 |
+
# Up.
|
180 |
+
elif int(parts[-1]) > 9 and up_res_block_flag < len(up_res_blocks):
|
181 |
+
entry_flow = layer.entry_flow
|
182 |
+
embedding_flow = layer.embedding_flow
|
183 |
+
exit_flow = layer.exit_flow
|
184 |
+
|
185 |
+
up_res_block = up_res_blocks[up_res_block_flag]
|
186 |
+
up_block_id = up_res_block[0]
|
187 |
+
up_resnet_id = up_res_block[1]
|
188 |
+
|
189 |
+
# Conv blocks.
|
190 |
+
first_conv_layer = entry_flow[-1]
|
191 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.weight"] = torch.from_numpy(first_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
192 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv1.bias"] = torch.from_numpy(first_conv_layer.get_weights()[1])
|
193 |
+
second_conv_layer = exit_flow[-1]
|
194 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.weight"] = torch.from_numpy(second_conv_layer.get_weights()[0].transpose(3, 2, 0, 1))
|
195 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv2.bias"] = torch.from_numpy(second_conv_layer.get_weights()[1])
|
196 |
+
|
197 |
+
# Residual blocks.
|
198 |
+
if hasattr(layer, "residual_projection"):
|
199 |
+
if isinstance(layer.residual_projection, stable_diffusion.__internal__.layers.padded_conv2d.PaddedConv2D):
|
200 |
+
residual = layer.residual_projection
|
201 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.weight"] = torch.from_numpy(residual.get_weights()[0].transpose(3, 2, 0, 1))
|
202 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.conv_shortcut.bias"] = torch.from_numpy(residual.get_weights()[1])
|
203 |
+
|
204 |
+
# Timestep embedding.
|
205 |
+
embedding_proj = embedding_flow[-1]
|
206 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.weight"] = torch.from_numpy(embedding_proj.get_weights()[0].transpose())
|
207 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.time_emb_proj.bias"] = torch.from_numpy(embedding_proj.get_weights()[1])
|
208 |
+
|
209 |
+
# Norms.
|
210 |
+
first_group_norm = entry_flow[0]
|
211 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.weight"] = torch.from_numpy(first_group_norm.get_weights()[0])
|
212 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm1.bias"] = torch.from_numpy(first_group_norm.get_weights()[1])
|
213 |
+
second_group_norm = exit_flow[0]
|
214 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.weight"] = torch.from_numpy(second_group_norm.get_weights()[0])
|
215 |
+
unet_state_dict[f"up_blocks.{up_block_id}.resnets.{up_resnet_id}.norm2.bias"] = torch.from_numpy(second_group_norm.get_weights()[1])
|
216 |
+
|
217 |
+
up_res_block_flag += 1
|
218 |
+
|
219 |
+
# All SpatialTransformer blocks.
|
220 |
+
elif isinstance(layer, stable_diffusion.diffusion_model.SpatialTransformer):
|
221 |
+
layer_name = layer.name
|
222 |
+
parts = layer_name.split("_")
|
223 |
+
|
224 |
+
# Down.
|
225 |
+
if len(parts) == 2 or int(parts[-1]) < 6:
|
226 |
+
down_block_id = 0 if len(parts) == 2 else int(parts[-1]) // 2
|
227 |
+
down_attention_id = 0 if len(parts) == 2 else int(parts[-1]) % 2
|
228 |
+
|
229 |
+
# Convs.
|
230 |
+
proj1 = layer.proj1
|
231 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
|
232 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
|
233 |
+
proj2 = layer.proj2
|
234 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
|
235 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
|
236 |
+
|
237 |
+
# Transformer blocks.
|
238 |
+
transformer_block = layer.transformer_block
|
239 |
+
unet_state_dict.update(port_transformer_block(transformer_block, "down", down_block_id, down_attention_id))
|
240 |
+
|
241 |
+
# Norms.
|
242 |
+
norm = layer.norm
|
243 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
|
244 |
+
unet_state_dict[f"down_blocks.{down_block_id}.attentions.{down_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
|
245 |
+
|
246 |
+
# Middle.
|
247 |
+
elif int(parts[-1]) == 6:
|
248 |
+
mid_attention_id = int(parts[-1]) % 2
|
249 |
+
# Convs.
|
250 |
+
proj1 = layer.proj1
|
251 |
+
unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
|
252 |
+
unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
|
253 |
+
proj2 = layer.proj2
|
254 |
+
unet_state_dict[f"mid_block.attentions.{mid_resnet_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
|
255 |
+
unet_state_dict[f"mid_block.attentions.{mid_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
|
256 |
+
|
257 |
+
# Transformer blocks.
|
258 |
+
transformer_block = layer.transformer_block
|
259 |
+
unet_state_dict.update(port_transformer_block(transformer_block, "mid", None, mid_attention_id))
|
260 |
+
|
261 |
+
# Norms.
|
262 |
+
norm = layer.norm
|
263 |
+
unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
|
264 |
+
unet_state_dict[f"mid_block.attentions.{mid_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
|
265 |
+
|
266 |
+
# Up.
|
267 |
+
elif int(parts[-1]) > 6 and up_spatial_transformer_flag < len(up_spatial_transformer_blocks):
|
268 |
+
up_spatial_transformer_block = up_spatial_transformer_blocks[up_spatial_transformer_flag]
|
269 |
+
up_block_id = up_spatial_transformer_block[0]
|
270 |
+
up_attention_id = up_spatial_transformer_block[1]
|
271 |
+
|
272 |
+
# Convs.
|
273 |
+
proj1 = layer.proj1
|
274 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.weight"] = torch.from_numpy(proj1.get_weights()[0].transpose(3, 2, 0, 1))
|
275 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_in.bias"] = torch.from_numpy(proj1.get_weights()[1])
|
276 |
+
proj2 = layer.proj2
|
277 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.weight"] = torch.from_numpy(proj2.get_weights()[0].transpose(3, 2, 0, 1))
|
278 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.proj_out.bias"] = torch.from_numpy(proj2.get_weights()[1])
|
279 |
+
|
280 |
+
# Transformer blocks.
|
281 |
+
transformer_block = layer.transformer_block
|
282 |
+
unet_state_dict.update(port_transformer_block(transformer_block, "up", up_block_id, up_attention_id))
|
283 |
+
|
284 |
+
# Norms.
|
285 |
+
norm = layer.norm
|
286 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.weight"] = torch.from_numpy(norm.get_weights()[0])
|
287 |
+
unet_state_dict[f"up_blocks.{up_block_id}.attentions.{up_attention_id}.norm.bias"] = torch.from_numpy(norm.get_weights()[1])
|
288 |
+
|
289 |
+
up_spatial_transformer_flag += 1
|
290 |
+
|
291 |
+
return unet_state_dict
|
conversion_utils/utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
|
7 |
+
def run_assertion(orig_pt_state_dict: Dict[str, torch.Tensor], pt_state_dict_from_tf: Dict[str, torch.Tensor]):
|
8 |
+
for k in orig_pt_state_dict:
|
9 |
+
try:
|
10 |
+
np.testing.assert_allclose(
|
11 |
+
orig_pt_state_dict[k].numpy(),
|
12 |
+
pt_state_dict_from_tf[k].numpy()
|
13 |
+
)
|
14 |
+
except:
|
15 |
+
raise ValueError("There are problems in the parameter population process. Cannot proceed :(")
|
convert.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from conversion_utils import populate_text_encoder, populate_unet, run_assertion
|
2 |
+
|
3 |
+
from diffusers import (
|
4 |
+
AutoencoderKL,
|
5 |
+
StableDiffusionPipeline,
|
6 |
+
UNet2DConditionModel,
|
7 |
+
)
|
8 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
9 |
+
from transformers import CLIPTextModel
|
10 |
+
import keras_cv
|
11 |
+
import tensorflow as tf
|
12 |
+
|
13 |
+
|
14 |
+
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
|
15 |
+
REVISION = None
|
16 |
+
NON_EMA_REVISION = None
|
17 |
+
IMG_HEIGHT = IMG_WIDTH = 512
|
18 |
+
|
19 |
+
def initialize_pt_models():
|
20 |
+
"""Initializes the separate models of Stable Diffusion from diffusers and downloads
|
21 |
+
their pre-trained weights."""
|
22 |
+
pt_text_encoder = CLIPTextModel.from_pretrained(
|
23 |
+
PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
|
24 |
+
)
|
25 |
+
pt_vae = AutoencoderKL.from_pretrained(
|
26 |
+
PRETRAINED_CKPT, subfolder="vae", revision=REVISION
|
27 |
+
)
|
28 |
+
pt_unet = UNet2DConditionModel.from_pretrained(
|
29 |
+
PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
|
30 |
+
)
|
31 |
+
pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
32 |
+
PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
|
33 |
+
)
|
34 |
+
|
35 |
+
return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
|
36 |
+
|
37 |
+
def initialize_tf_models():
|
38 |
+
"""Initializes the separate models of Stable Diffusion from KerasCV and downloads
|
39 |
+
their pre-trained weights."""
|
40 |
+
tf_sd_model = keras_cv.models.StableDiffusion(img_height=IMG_HEIGHT, img_width=IMG_WIDTH)
|
41 |
+
_ = tf_sd_model.text_to_image("Cartoon") # To download the weights.
|
42 |
+
|
43 |
+
tf_text_encoder = tf_sd_model.text_encoder
|
44 |
+
tf_vae = tf_sd_model.image_encoder
|
45 |
+
tf_unet = tf_sd_model.diffusion_model
|
46 |
+
return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
|
47 |
+
|
48 |
+
|
49 |
+
def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
|
50 |
+
pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
|
51 |
+
tf_sd_model, tf_text_encoder, tf_vae, tf_unet = initialize_tf_models()
|
52 |
+
print("Pre-trained model weights downloaded.")
|
53 |
+
|
54 |
+
if text_encoder_weights is not None:
|
55 |
+
print("Loading fine-tuned text encoder weights.")
|
56 |
+
text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
|
57 |
+
tf_text_encoder.load_weights(text_encoder_weights_path)
|
58 |
+
if unet_weights is not None:
|
59 |
+
print("Loading fine-tuned UNet weights.")
|
60 |
+
unet_weights_path = tf.keras.utils.get_file(unet_weights)
|
61 |
+
tf_unet.load_weights(unet_weights_path)
|
62 |
+
|
63 |
+
text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
|
64 |
+
unet_state_dict_from_tf = populate_unet(tf_unet)
|
65 |
+
print("Conversion done, now running assertions...")
|
66 |
+
|
67 |
+
# Since we cannot compare the fine-tuned weights.
|
68 |
+
if text_encoder_weights is None:
|
69 |
+
text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
|
70 |
+
run_assertion(text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf)
|
71 |
+
if unet_weights is None:
|
72 |
+
unet_state_dict_from_pt = pt_text_encoder.state_dict()
|
73 |
+
run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
|
74 |
+
|
75 |
+
print("Assertions successful, populating the converted parameters into the diffusers models...")
|
76 |
+
pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
|
77 |
+
pt_unet.load_state_dict(unet_state_dict_from_tf)
|
78 |
+
|
79 |
+
print("Parameters ported, preparing StabelDiffusionPipeline...")
|
80 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
81 |
+
PRETRAINED_CKPT,
|
82 |
+
unet=pt_unet,
|
83 |
+
text_encoder=pt_text_encoder,
|
84 |
+
vae=pt_vae,
|
85 |
+
safety_checker=pt_safety_checker,
|
86 |
+
revision=None,
|
87 |
+
)
|
88 |
+
return pipeline
|
89 |
+
|
90 |
+
|
hub_utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .readme import save_model_card
|
2 |
+
from .repo import push_to_hub
|
hub_utils/readme.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
# Copied from https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/text_to_image/train_text_to_image_lora.py#L55
|
5 |
+
def save_model_card(base_model=str, repo_folder=None, weight_paths=None):
|
6 |
+
yaml = f"""
|
7 |
+
---
|
8 |
+
license: creativeml-openrail-m
|
9 |
+
base_model: {base_model}
|
10 |
+
tags:
|
11 |
+
- stable-diffusion
|
12 |
+
- stable-diffusion-diffusers
|
13 |
+
- text-to-image
|
14 |
+
- diffusers
|
15 |
+
inference: true
|
16 |
+
---
|
17 |
+
"""
|
18 |
+
model_card = f"""
|
19 |
+
# KerasCV Stable Diffusion in Diffusers 🧨🤗
|
20 |
+
|
21 |
+
The pipeline contained in this repository was created using [this Space](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers). The purpose is to convert the KerasCV Stable Diffusion weights in a way that is compatible with Diffusers. This allows users to fine-tune using KerasCV and use the fine-tuned weights in Diffusers taking advantage of its nifty features (like schedulers, fast attention, etc.).\n
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
if weight_paths is not None:
|
26 |
+
model_card += "Following weight paths (KerasCV) were used: {weight_paths}"
|
27 |
+
|
28 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
29 |
+
f.write(yaml + model_card)
|
hub_utils/repo.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import HfApi, create_repo
|
2 |
+
|
3 |
+
def push_to_hub(hf_token: str, push_dir: str, repo_prefix: None) -> str:
|
4 |
+
try:
|
5 |
+
if hf_token == "":
|
6 |
+
return "No HF token provided. Model won't be pushed."
|
7 |
+
else:
|
8 |
+
hf_api = HfApi(token=hf_token)
|
9 |
+
user = hf_api.whoami()["name"]
|
10 |
+
repo_id = f"{user}/{push_dir}" if repo_prefix == "" else f"{user}/{repo_prefix}-{push_dir}"
|
11 |
+
_ = create_repo(repo_id=repo_id, token=hf_token)
|
12 |
+
url = hf_api.upload_folder(folder_path=push_dir, repo_id=repo_id, exist_ok=True)
|
13 |
+
return f"Model successfully pushed: [{url}]({url})"
|
14 |
+
except Exception as e:
|
15 |
+
return f"{e}"
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.25.1
|
2 |
+
numpy==1.21.6
|
3 |
+
torch==1.12.1
|
4 |
+
tensorflow==2.10.0
|
5 |
+
git+https://github.com/keras-team/keras-cv.git@master
|
6 |
+
git+https://github.com/huggingface/diffusers.git@main
|
7 |
+
tensorflow-datasets==4.8.0
|