feat: requirements, cleanup

Files changed:
- Dockerfile (+22 -0)
- README.md (+6 -6)
- app.py (+77 -28)
- modules/lora.py (+181 -0)
- modules/model.py (+0 -144)
- requirements.txt (+0 -8)
Dockerfile (ADDED)
@@ -0,0 +1,22 @@
+# Dockerfile Public T4
+
+FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ENV DEBIAN_FRONTEND noninteractive
+
+WORKDIR /content
+
+RUN apt-get update -y && apt-get upgrade -y && apt-get install -y libgl1 libglib2.0-0 wget git git-lfs python3-pip python-is-python3 && pip3 install --upgrade pip
+
+RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchsde --extra-index-url https://download.pytorch.org/whl/cu113
+RUN pip install https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.16/xformers-0.0.16+814314d.d20230118-cp310-cp310-linux_x86_64.whl
+RUN pip install --pre triton
+RUN pip install numexpr einops diffusers transformers k_diffusion safetensors gradio
+
+ADD . .
+RUN adduser --disabled-password --gecos '' user
+RUN chown -R user:user /content
+RUN chmod -R 777 /content
+USER user
+
+EXPOSE 7860
+CMD python /content/app.py
README.md (CHANGED)
@@ -1,13 +1,13 @@
 ---
 title: Sd Diffusers Webui
-emoji:
-colorFrom:
-colorTo:
-sdk:
-sdk_version: 3.
-app_file: app.py
+emoji: 🐳
+colorFrom: purple
+colorTo: gray
+sdk: docker
+sdk_version: 3.9
 pinned: false
 license: openrail
+app_port: 7860
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py (CHANGED)
@@ -4,6 +4,7 @@ import time
 import gradio as gr
 import numpy as np
 import torch
+import math
 
 from gradio import inputs
 from diffusers import (
@@ -14,7 +15,6 @@ from diffusers import (
 from modules.model import (
     CrossAttnProcessor,
     StableDiffusionPipeline,
-    load_lora_attn_procs,
 )
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel
@@ -22,16 +22,17 @@ from PIL import Image
 from pathlib import Path
 from safetensors.torch import load_file
 import modules.safe as _
+from modules.lora import LoRANetwork
 
 models = [
-
-    ("
-    ("
-    ("
+    # format: name, model_path, clip_skip
+    ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF", 2),
+    ("Basil Mix", "nuigurumi/basil_mix", 2),
+    ("Pastal Mix", "andite/pastel-mix", 2),
+    ("ACertainModel", "JosephusCheung/ACertainModel", 2),
 ]
 
-base_name, base_model = models[0]
-clip_skip = 2
+base_name, base_model, clip_skip = models[0]
 
 samplers_k_diffusion = [
     ("Euler a", "sample_euler_ancestral", {}),
@@ -103,6 +104,10 @@ unet_cache = {
     base_name: unet
 }
 
+lora_cache = {
+    base_name: LoRANetwork(text_encoder, unet)
+}
+
 def get_model(name):
     keys = [k[0] for k in models]
     if name not in unet_cache:
@@ -114,11 +119,21 @@ def get_model(name):
             subfolder="unet",
             torch_dtype=torch.float16,
         )
+        if torch.cuda.is_available():
+            unet.to("cuda")
+
         unet_cache[name] = unet
+        lora_cache[name] = LoRANetwork(lora_cache[base_name].text_encoder_loras, unet)
 
     g_unet = unet_cache[name]
-
-
+    g_lora = lora_cache[name]
+    g_unet.set_attn_processor(CrossAttnProcessor())
+    g_lora.reset()
+    return g_unet, g_lora
+
+# precache on huggingface
+# for model in get_model_list():
+#     get_model(model[0])
 
 def error_str(error, title="Error"):
     return (
@@ -129,18 +144,46 @@ def error_str(error, title="Error"):
     )
 
 
-
+te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
+original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
+
+def make_token_names(embs):
+    all_tokens = []
+    for name, vec in embs.items():
+        tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
+        all_tokens.append(tokens)
+    return all_tokens
+
+def setup_tokenizer(embs):
+    reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
+    clip_keywords = [' '.join(s) for s in make_token_names(embs)]
+
+    def parse_prompt(prompt: str):
+        for m, v in zip(reg_match, clip_keywords):
+            prompt = m.sub(v, prompt)
+        return prompt
 
 
 def restore_all():
     global te_base_weight, tokenizer
-
+    tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
+
+    embeddings = text_encoder.get_input_embeddings()
+    text_encoder.get_input_embeddings().weight.data = embeddings.weight.data[:te_base_weight_length]
    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
        torch_dtype=torch.float16,
    )
 
+def convert_size(size_bytes):
+    if size_bytes == 0:
+        return "0B"
+    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
+    i = int(math.floor(math.log(size_bytes, 1024)))
+    p = math.pow(1024, i)
+    s = round(size_bytes / p, 2)
+    return "%s %s" % (s, size_name[i])
 
 def inference(
     prompt,
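A quick illustration of what the new embedding helpers above do (not part of the commit): make_token_names() expands each multi-vector embedding into placeholder tokens named emb-<name>-<i>, and setup_tokenizer() builds regexes that swap the embedding's trigger word in a prompt for that run of tokens (presumably wired into tokenizer.prepare_for_tokenization, given the save/restore of original_prepare_for_tokenization). The embedding names and vector counts below are hypothetical:

import re
import torch

embs = {"easynegative": torch.zeros(2, 768), "badhand": torch.zeros(1, 768)}  # made-up embeddings

# same logic as make_token_names / setup_tokenizer in the hunk above
all_tokens = [[f'emb-{name}-{i}' for i in range(len(vec))] for name, vec in embs.items()]
reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
clip_keywords = [' '.join(s) for s in all_tokens]

prompt = "masterpiece, easynegative, badhand"
for m, v in zip(reg_match, clip_keywords):
    prompt = m.sub(v, prompt)
print(prompt)  # masterpiece, emb-easynegative-0 emb-easynegative-1, emb-badhand-0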
@@ -167,13 +210,15 @@ def inference(
     global pipe, unet, tokenizer, text_encoder
     if seed is None or seed == 0:
         seed = random.randint(0, 2147483647)
+
+    start_time = time.time()
+    restore_all()
     generator = torch.Generator("cuda").manual_seed(int(seed))
 
-    local_unet = get_model(model)
+    local_unet, local_lora = get_model(model)
     if lora_state is not None and lora_state != "":
-
-
-        local_unet.set_attn_processor(CrossAttnProcessor())
+        local_lora.load(lora_state, lora_scale)
+        local_lora.to(local_unet.device, dtype=local_unet.dtype)
 
     pipe.setup_unet(local_unet)
     sampler_name, sampler_opt = None, None
@@ -182,23 +227,23 @@ def inference(
             sampler_name, sampler_opt = funcname, options
 
     if embs is not None and len(embs) > 0:
-
+        ti_embs = {}
         for name, file in embs.items():
             if str(file).endswith(".pt"):
                 loaded_learned_embeds = torch.load(file, map_location="cpu")
             else:
                 loaded_learned_embeds = load_file(file, device="cpu")
             loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
-
+            ti_embs[name] = loaded_learned_embeds
 
-
-
+        if len(ti_embs) > 0:
+            tokens = setup_tokenizer(ti_embs)
+            added_tokens = tokenizer.add_tokens(tokens)
+            delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
 
-
-
-
-                -delta_weight.shape[0] :
-            ] = delta_weight
+            assert added_tokens == delta_weight.shape[0]
+            text_encoder.resize_token_embeddings(len(tokenizer))
+            text_encoder.get_input_embeddings().weight.data[-delta_weight.shape[0]:] = delta_weight
 
     config = {
         "negative_prompt": neg_prompt,
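The hunk above concatenates all loaded textual-inversion vectors into delta_weight and writes them into the freshly resized rows of the text encoder's embedding table; restore_all() later truncates the table back to te_base_weight_length. A toy sketch of that bookkeeping, with made-up sizes:

import torch

base_len, dim = 49408, 768                              # stand-ins for the real CLIP table
weight = torch.randn(base_len, dim)
ti_embs = {"a": torch.randn(2, dim), "b": torch.randn(3, dim)}

delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)     # shape (5, dim)
weight = torch.cat([weight, torch.zeros(delta_weight.shape[0], dim)])  # resize_token_embeddings analogue
weight[-delta_weight.shape[0]:] = delta_weight                         # the added assignment

weight = weight[:base_len]                              # what restore_all() does afterwards
assert weight.shape[0] == base_len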
@@ -234,6 +279,10 @@ def inference(
     # restore
     if embs is not None and len(embs) > 0:
         restore_all()
+
+    end_time = time.time()
+    vram_free, vram_total = torch.cuda.mem_get_info()
+    print(f"done: res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
     return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
 
 
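convert_size() (added earlier in this diff) is what turns the raw byte counts from torch.cuda.mem_get_info() into the readable units in this log line. With made-up numbers:

vram_free, vram_total = 7 * 1024**3, 15 * 1024**3    # stand-ins for torch.cuda.mem_get_info()
print(convert_size(vram_total - vram_free))          # 8.0 GB
print(convert_size(vram_total))                      # 15.0 GB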
@@ -513,7 +562,7 @@ with gr.Blocks(css=css) as demo:
                     label="Guidance scale", value=7.5, maximum=15
                 )
                 steps = gr.Slider(
-                    label="Steps", value=25, minimum=2, maximum=
+                    label="Steps", value=25, minimum=2, maximum=50, step=1
                 )
 
                 with gr.Row():
@@ -704,7 +753,7 @@ with gr.Blocks(css=css) as demo:
                         step=0.01,
                         value=0.5,
                     )
-
+
 
                 sk_update.click(
                     detect_text,
@@ -739,7 +788,7 @@ with gr.Blocks(css=css) as demo:
                     source="upload",
                     shape=(512, 512),
                 )
-
+
                 mask_outsides2 = gr.Checkbox(
                     label="Mask other areas",
                     value=False
@@ -803,4 +852,4 @@ with gr.Blocks(css=css) as demo:
 
 print(f"Space built in {time.time() - start_time:.2f} seconds")
 # demo.launch(share=True)
-demo.launch()
+demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
modules/lora.py (ADDED)
@@ -0,0 +1,181 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+# https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
+
+import math
+import os
+import torch
+import modules.safe as _
+from safetensors.torch import load_file
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    replaces forward method of the original Linear, instead of replacing the original Linear module.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+        self.lora_dim = lora_dim
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+            self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
+            self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+            self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+
+        alpha = lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.enable = False
+
+    def resize(self, rank, alpha):
+        self.alpha = torch.tensor(alpha)
+        self.scale = alpha / rank
+        if self.lora_down.__class__.__name__ == "Conv2d":
+            in_dim = self.lora_down.in_channels
+            out_dim = self.lora_up.out_channels
+            self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
+            self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
+        else:
+            in_dim = self.lora_down.in_features
+            out_dim = self.lora_up.out_features
+            self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
+            self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
+
+    def apply(self):
+        if hasattr(self, "org_module"):
+            self.org_forward = self.org_module.forward
+            self.org_module.forward = self.forward
+            del self.org_module
+
+    def forward(self, x):
+        if self.enable:
+            return (
+                self.org_forward(x)
+                + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
+            )
+        return self.org_forward(x)
+
+
+class LoRANetwork(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+
+        # create module instances
+        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
+            loras = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+                            lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
+                            loras.append(lora)
+            return loras
+
+        if isinstance(text_encoder, list):
+            self.text_encoder_loras = text_encoder
+        else:
+            self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
+        print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        self.weights_sd = None
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+            lora.apply()
+            self.add_module(lora.lora_name, lora)
+
+    def reset(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.enable = False
+
+    def load(self, file, scale):
+
+        weights = None
+        if os.path.splitext(file)[1] == ".safetensors":
+            weights = load_file(file)
+        else:
+            weights = torch.load(file, map_location="cpu")
+
+        if not weights:
+            return
+
+        network_alpha = None
+        network_dim = None
+        for key, value in weights.items():
+            if network_alpha is None and "alpha" in key:
+                network_alpha = value
+            if network_dim is None and "lora_down" in key and len(value.size()) == 2:
+                network_dim = value.size()[0]
+
+        if network_alpha is None:
+            network_alpha = network_dim
+
+        weights_has_text_encoder = weights_has_unet = False
+        weights_to_modify = []
+
+        for key in weights.keys():
+            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                weights_has_text_encoder = True
+
+            if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                weights_has_unet = True
+
+        if weights_has_text_encoder:
+            weights_to_modify += self.text_encoder_loras
+
+        if weights_has_unet:
+            weights_to_modify += self.unet_loras
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.resize(network_dim, network_alpha)
+            if lora in weights_to_modify:
+                lora.enable = True
+
+        info = self.load_state_dict(weights, False)
+        print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
+
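The patching trick in LoRAModule, demonstrated outside the commit: apply() saves the module's bound forward and swaps in the LoRA one. Because lora_up is zero-initialized, a freshly patched layer is a no-op even when enabled:

import torch
from modules.lora import LoRAModule

linear = torch.nn.Linear(16, 16)
lora = LoRAModule("demo", linear, multiplier=1.0, lora_dim=4, alpha=4)
lora.apply()                  # linear.forward now routes through lora.forward

x = torch.randn(1, 16)
lora.enable = True
# lora_up starts at zero, so the patched output still equals the original
assert torch.allclose(linear(x), lora.org_forward(x))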
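And the call sequence app.py now uses, condensed into a standalone sketch (the model repo comes from the models list above; the LoRA file path is a placeholder):

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel
from modules.lora import LoRANetwork

text_encoder = CLIPTextModel.from_pretrained("Korakoe/AbyssOrangeMix2-HF", subfolder="text_encoder", torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained("Korakoe/AbyssOrangeMix2-HF", subfolder="unet", torch_dtype=torch.float16)

network = LoRANetwork(text_encoder, unet)           # patches Linear / 1x1 Conv2d forwards
network.reset()                                     # all LoRA modules start disabled
network.load("some_lora.safetensors", scale=1.0)    # resizes ranks, enables matched modules
network.to(unet.device, dtype=unet.dtype)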
modules/model.py (CHANGED)
@@ -68,79 +68,6 @@ def get_attention_scores(attn, query, key, attention_mask=None):
     return attention_scores
 
 
-def load_lora_attn_procs(model_file, unet, scale=1.0):
-
-    if Path(model_file).suffix == ".pt":
-        state_dict = torch.load(model_file, map_location="cpu")
-    else:
-        state_dict = load_file(model_file, device="cpu")
-
-    if any("lora_unet_down_blocks" in k for k in state_dict.keys()):
-        # convert ldm format lora
-        df_lora = {}
-        attn_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).lora_")
-        alpha_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).alpha")
-        for k, v in state_dict.items():
-            if "attn" not in k or "lora_te" in k:
-                # currently not support: ff, clip-attn
-                continue
-            k = k.replace("lora_unet_down_blocks_", "down_blocks.")
-            k = k.replace("lora_unet_up_blocks_", "up_blocks.")
-            k = k.replace("lora_unet_mid_block_", "mid_block_")
-            k = k.replace("_attentions_", ".attentions.")
-            k = k.replace("_transformer_blocks_", ".transformer_blocks.")
-            k = k.replace("to_out_0", "to_out")
-            k = attn_numlayer.sub(r".attn\1.processor.to_\2_lora.", k)
-            k = alpha_numlayer.sub(r".attn\1.processor.to_\2_lora.alpha", k)
-            df_lora[k] = v
-        state_dict = df_lora
-
-    # fill attn processors
-    attn_processors = {}
-
-    is_lora = all("lora" in k for k in state_dict.keys())
-
-    if is_lora:
-        lora_grouped_dict = defaultdict(dict)
-        for key, value in state_dict.items():
-            if "alpha" in key:
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(
-                    key.split(".")[-2:]
-                )
-            else:
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(
-                    key.split(".")[-3:]
-                )
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-        for key, value_dict in lora_grouped_dict.items():
-            rank = value_dict["to_k_lora.down.weight"].shape[0]
-            cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
-            hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
-            attn_processors[key] = LoRACrossAttnProcessor(
-                hidden_size=hidden_size,
-                cross_attention_dim=cross_attention_dim,
-                rank=rank,
-                scale=scale,
-            )
-            attn_processors[key].load_state_dict(value_dict, strict=False)
-
-    else:
-        raise ValueError(
-            f"{model_file} does not seem to be in the correct format expected by LoRA training."
-        )
-
-    # set correct dtype & device
-    attn_processors = {
-        k: v.to(device=unet.device, dtype=unet.dtype)
-        for k, v in attn_processors.items()
-    }
-
-    # set layers
-    unet.set_attn_processor(attn_processors)
-
-
 class CrossAttnProcessor(nn.Module):
     def __call__(
         self,
@@ -148,7 +75,6 @@ class CrossAttnProcessor(nn.Module):
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
-        qkvo_bias=None,
     ):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -166,11 +92,6 @@ class CrossAttnProcessor(nn.Module):
         key = attn.to_k(encoder_states)
         value = attn.to_v(encoder_states)
 
-        if qkvo_bias is not None:
-            query += qkvo_bias["q"](hidden_states)
-            key += qkvo_bias["k"](encoder_states)
-            value += qkvo_bias["v"](encoder_states)
-
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
@@ -219,76 +140,11 @@ class CrossAttnProcessor(nn.Module):
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
 
-        if qkvo_bias is not None:
-            hidden_states += qkvo_bias["o"](hidden_states)
-
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
-
-class LoRACrossAttnProcessor(CrossAttnProcessor):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, scale=1.0):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(
-            cross_attention_dim or hidden_size, hidden_size, rank
-        )
-        self.to_v_lora = LoRALinearLayer(
-            cross_attention_dim or hidden_size, hidden_size, rank
-        )
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.scale = scale
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-    ):
-        scale = self.scale
-        qkvo_bias = {
-            "q": lambda inputs: scale * self.to_q_lora(inputs),
-            "k": lambda inputs: scale * self.to_k_lora(inputs),
-            "v": lambda inputs: scale * self.to_v_lora(inputs),
-            "o": lambda inputs: scale * self.to_out_lora(inputs),
-        }
-        return super().__call__(
-            attn, hidden_states, encoder_hidden_states, attention_mask, qkvo_bias
-        )
-
-
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(
-                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
-            )
-
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.scale = 1.0
-        self.alpha = rank
-
-        nn.init.normal_(self.down.weight, std=1 / rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-        rank = self.down.out_features
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states) * (self.alpha / rank)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ModelWrapper:
     def __init__(self, model, alphas_cumprod):
         self.model = model
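The deleted code applied LoRA inside a custom attention processor (LoRACrossAttnProcessor injecting qkvo_bias terms); the new modules/lora.py reaches the same result by patching module forwards instead. The low-rank update itself is the same in both, sketched here:

import torch
import torch.nn as nn

# y = W(x) + up(down(x)) * scale -- the residual both implementations add
W = nn.Linear(64, 64)
down = nn.Linear(64, 4, bias=False)   # rank 4
up = nn.Linear(4, 64, bias=False)
nn.init.zeros_(up.weight)             # both versions zero-init the up projection
scale = 1.0

x = torch.randn(1, 64)
y = W(x) + up(down(x)) * scale
assert torch.allclose(y, W(x))        # no-op until trained weights are loaded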
requirements.txt (DELETED)
@@ -1,8 +0,0 @@
-torch
-einops
-diffusers
-transformers
-k_diffusion
-safetensors
-gradio
-torch