Spaces:
Running
on
Zero
Running
on
Zero
first commit
Browse files- .gitignore +1 -0
- README.md +8 -5
- app.py +873 -0
- examples/prompt_background.txt +8 -0
- examples/prompt_background_advanced.txt +0 -0
- examples/prompt_boy.txt +15 -0
- examples/prompt_girl.txt +16 -0
- examples/prompt_props.txt +43 -0
- model.py +1410 -0
- prompt_util.py +154 -0
- requirements.txt +14 -0
- share_btn.py +59 -0
- util.py +315 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.ipynb_checkpoints/*
|
README.md
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: red
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: mit
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: SemanticPalette X Animagine XL 3.1
|
3 |
+
emoji: π₯π§ π¨π₯
|
4 |
colorFrom: red
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.21.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: mit
|
11 |
+
models:
|
12 |
+
- cagliostrolab/animagine-xl-3.1
|
13 |
+
- ByteDance/SDXL-Lightning
|
14 |
---
|
15 |
|
16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,873 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
import sys
|
22 |
+
|
23 |
+
sys.path.append('../../src')
|
24 |
+
|
25 |
+
import argparse
|
26 |
+
import random
|
27 |
+
import time
|
28 |
+
import json
|
29 |
+
import os
|
30 |
+
import glob
|
31 |
+
import pathlib
|
32 |
+
from functools import partial
|
33 |
+
from pprint import pprint
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
from PIL import Image
|
37 |
+
import torch
|
38 |
+
|
39 |
+
import gradio as gr
|
40 |
+
from huggingface_hub import snapshot_download
|
41 |
+
|
42 |
+
from model import StableMultiDiffusionSDXLPipeline
|
43 |
+
from util import seed_everything
|
44 |
+
from prompt_util import preprocess_prompts, _quality_dict, _style_dict
|
45 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
46 |
+
|
47 |
+
|
48 |
+
### Utils
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
def log_state(state):
|
54 |
+
pprint(vars(opt))
|
55 |
+
if isinstance(state, gr.State):
|
56 |
+
state = state.value
|
57 |
+
pprint(vars(state))
|
58 |
+
|
59 |
+
|
60 |
+
def is_empty_image(im: Image.Image) -> bool:
|
61 |
+
if im is None:
|
62 |
+
return True
|
63 |
+
im = np.array(im)
|
64 |
+
has_alpha = (im.shape[2] == 4)
|
65 |
+
if not has_alpha:
|
66 |
+
return False
|
67 |
+
elif im.sum() == 0:
|
68 |
+
return True
|
69 |
+
else:
|
70 |
+
return False
|
71 |
+
|
72 |
+
|
73 |
+
### Argument passing
|
74 |
+
|
75 |
+
parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SDXL support.')
|
76 |
+
parser.add_argument('-H', '--height', type=int, default=1024)
|
77 |
+
parser.add_argument('-W', '--width', type=int, default=2560)
|
78 |
+
parser.add_argument('--model', type=str, default=None, help='Hugging face model repository or local path for a SD1.5 model checkpoint to run.')
|
79 |
+
parser.add_argument('--bootstrap_steps', type=int, default=1)
|
80 |
+
parser.add_argument('--seed', type=int, default=-1)
|
81 |
+
parser.add_argument('--device', type=int, default=0)
|
82 |
+
parser.add_argument('--port', type=int, default=8000)
|
83 |
+
opt = parser.parse_args()
|
84 |
+
|
85 |
+
|
86 |
+
### Global variables and data structures
|
87 |
+
|
88 |
+
device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'
|
89 |
+
|
90 |
+
|
91 |
+
if opt.model is None:
|
92 |
+
model_dict = {
|
93 |
+
'Animagine XL 3.1': 'cagliostrolab/animagine-xl-3.1',
|
94 |
+
}
|
95 |
+
else:
|
96 |
+
if opt.model.endswith('.safetensors'):
|
97 |
+
opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))
|
98 |
+
model_dict = {os.path.splitext(os.path.basename(opt.model))[0]: opt.model}
|
99 |
+
|
100 |
+
models = {
|
101 |
+
k: StableMultiDiffusionSDXLPipeline(device, hf_key=v, has_i2t=False)
|
102 |
+
for k, v in model_dict.items()
|
103 |
+
}
|
104 |
+
|
105 |
+
|
106 |
+
prompt_suggestions = [
|
107 |
+
'1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
|
108 |
+
'1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
|
109 |
+
'1girl, arima kana, oshi no ko, solo, upper body, from behind',
|
110 |
+
]
|
111 |
+
|
112 |
+
opt.max_palettes = 5
|
113 |
+
opt.default_prompt_strength = 1.0
|
114 |
+
opt.default_mask_strength = 1.0
|
115 |
+
opt.default_mask_std = 0.0
|
116 |
+
opt.default_negative_prompt = (
|
117 |
+
'nsfw, worst quality, bad quality, normal quality, cropped, framed'
|
118 |
+
)
|
119 |
+
opt.verbose = True
|
120 |
+
opt.colors = [
|
121 |
+
'#000000',
|
122 |
+
'#2692F3',
|
123 |
+
'#F89E12',
|
124 |
+
'#16C232',
|
125 |
+
'#F92F6C',
|
126 |
+
'#AC6AEB',
|
127 |
+
# '#92C62C',
|
128 |
+
# '#92C6EC',
|
129 |
+
# '#FECAC0',
|
130 |
+
]
|
131 |
+
|
132 |
+
|
133 |
+
### Event handlers
|
134 |
+
|
135 |
+
def add_palette(state):
|
136 |
+
old_actives = state.active_palettes
|
137 |
+
state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)
|
138 |
+
|
139 |
+
if opt.verbose:
|
140 |
+
log_state(state)
|
141 |
+
|
142 |
+
if state.active_palettes != old_actives:
|
143 |
+
return [state] + [
|
144 |
+
gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
|
145 |
+
] + [
|
146 |
+
gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
|
147 |
+
for i in range(opt.max_palettes)
|
148 |
+
]
|
149 |
+
else:
|
150 |
+
return [state] + [gr.update() for i in range(opt.max_palettes + 1)]
|
151 |
+
|
152 |
+
|
153 |
+
def select_palette(state, button, idx):
|
154 |
+
if idx < 0 or idx > opt.max_palettes:
|
155 |
+
idx = 0
|
156 |
+
old_idx = state.current_palette
|
157 |
+
if old_idx == idx:
|
158 |
+
return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]
|
159 |
+
|
160 |
+
state.current_palette = idx
|
161 |
+
|
162 |
+
if opt.verbose:
|
163 |
+
log_state(state)
|
164 |
+
|
165 |
+
updates = [state] + [
|
166 |
+
gr.update() if i not in (idx, old_idx) else
|
167 |
+
gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
|
168 |
+
for i in range(opt.max_palettes + 1)
|
169 |
+
]
|
170 |
+
label = 'Background' if idx == 0 else f'Palette {idx}'
|
171 |
+
updates.extend([
|
172 |
+
gr.update(value=button, interactive=(idx > 0)),
|
173 |
+
gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
|
174 |
+
gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
|
175 |
+
(
|
176 |
+
gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
|
177 |
+
gr.update(value=opt.default_mask_strength, interactive=False)
|
178 |
+
),
|
179 |
+
(
|
180 |
+
gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
|
181 |
+
gr.update(value=opt.default_prompt_strength, interactive=False)
|
182 |
+
),
|
183 |
+
(
|
184 |
+
gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
|
185 |
+
gr.update(value=opt.default_mask_std, interactive=False)
|
186 |
+
),
|
187 |
+
])
|
188 |
+
return updates
|
189 |
+
|
190 |
+
|
191 |
+
def change_prompt_strength(state, strength):
|
192 |
+
if state.current_palette == 0:
|
193 |
+
return state
|
194 |
+
|
195 |
+
state.prompt_strengths[state.current_palette - 1] = strength
|
196 |
+
if opt.verbose:
|
197 |
+
log_state(state)
|
198 |
+
|
199 |
+
return state
|
200 |
+
|
201 |
+
|
202 |
+
def change_std(state, std):
|
203 |
+
if state.current_palette == 0:
|
204 |
+
return state
|
205 |
+
|
206 |
+
state.mask_stds[state.current_palette - 1] = std
|
207 |
+
if opt.verbose:
|
208 |
+
log_state(state)
|
209 |
+
|
210 |
+
return state
|
211 |
+
|
212 |
+
|
213 |
+
def change_mask_strength(state, strength):
|
214 |
+
if state.current_palette == 0:
|
215 |
+
return state
|
216 |
+
|
217 |
+
state.mask_strengths[state.current_palette - 1] = strength
|
218 |
+
if opt.verbose:
|
219 |
+
log_state(state)
|
220 |
+
|
221 |
+
return state
|
222 |
+
|
223 |
+
|
224 |
+
def reset_seed(state, seed):
|
225 |
+
state.seed = seed
|
226 |
+
if opt.verbose:
|
227 |
+
log_state(state)
|
228 |
+
|
229 |
+
return state
|
230 |
+
|
231 |
+
def rename_prompt(state, name):
|
232 |
+
state.prompt_names[state.current_palette] = name
|
233 |
+
if opt.verbose:
|
234 |
+
log_state(state)
|
235 |
+
|
236 |
+
return [state] + [
|
237 |
+
gr.update() if i != state.current_palette else gr.update(value=name)
|
238 |
+
for i in range(opt.max_palettes + 1)
|
239 |
+
]
|
240 |
+
|
241 |
+
|
242 |
+
def change_prompt(state, prompt):
|
243 |
+
state.prompts[state.current_palette] = prompt
|
244 |
+
if opt.verbose:
|
245 |
+
log_state(state)
|
246 |
+
|
247 |
+
return state
|
248 |
+
|
249 |
+
|
250 |
+
def change_neg_prompt(state, neg_prompt):
|
251 |
+
state.neg_prompts[state.current_palette] = neg_prompt
|
252 |
+
if opt.verbose:
|
253 |
+
log_state(state)
|
254 |
+
|
255 |
+
return state
|
256 |
+
|
257 |
+
|
258 |
+
def select_model(state, model_id):
|
259 |
+
state.model_id = model_id
|
260 |
+
if opt.verbose:
|
261 |
+
log_state(state)
|
262 |
+
|
263 |
+
return state
|
264 |
+
|
265 |
+
|
266 |
+
def select_style(state, style_name):
|
267 |
+
state.style_name = style_name
|
268 |
+
if opt.verbose:
|
269 |
+
log_state(state)
|
270 |
+
|
271 |
+
return state
|
272 |
+
|
273 |
+
|
274 |
+
def select_quality(state, quality_name):
|
275 |
+
state.quality_name = quality_name
|
276 |
+
if opt.verbose:
|
277 |
+
log_state(state)
|
278 |
+
|
279 |
+
return state
|
280 |
+
|
281 |
+
|
282 |
+
def import_state(state, json_text):
|
283 |
+
current_palette = state.current_palette
|
284 |
+
# active_palettes = state.active_palettes
|
285 |
+
state = argparse.Namespace(**json.loads(json_text))
|
286 |
+
state.active_palettes = opt.max_palettes
|
287 |
+
return [state] + [
|
288 |
+
gr.update(value=v, visible=True) for v in state.prompt_names
|
289 |
+
] + [
|
290 |
+
state.model_id,
|
291 |
+
state.style_name,
|
292 |
+
state.quality_name,
|
293 |
+
state.prompts[current_palette],
|
294 |
+
state.prompt_names[current_palette],
|
295 |
+
state.neg_prompts[current_palette],
|
296 |
+
state.prompt_strengths[current_palette - 1],
|
297 |
+
state.mask_strengths[current_palette - 1],
|
298 |
+
state.mask_stds[current_palette - 1],
|
299 |
+
state.seed,
|
300 |
+
]
|
301 |
+
|
302 |
+
|
303 |
+
### Main worker
|
304 |
+
|
305 |
+
def generate(state, *args, **kwargs):
|
306 |
+
return models[state.model_id](*args, **kwargs)
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
def run(state, drawpad):
|
311 |
+
seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
|
312 |
+
print('Generate!')
|
313 |
+
|
314 |
+
background = drawpad['background'].convert('RGBA')
|
315 |
+
inpainting_mode = np.asarray(background).sum() != 0
|
316 |
+
print('Inpainting mode: ', inpainting_mode)
|
317 |
+
|
318 |
+
user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
|
319 |
+
foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
|
320 |
+
user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
|
321 |
+
|
322 |
+
palette = torch.tensor([
|
323 |
+
tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
|
324 |
+
for s in opt.colors[1:]
|
325 |
+
]) # (N, 3)
|
326 |
+
masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...] # (N, 1, H, W)
|
327 |
+
has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
|
328 |
+
print('Has mask: ', has_masks)
|
329 |
+
masks = masks * foreground_mask
|
330 |
+
masks = masks[has_masks]
|
331 |
+
|
332 |
+
if inpainting_mode:
|
333 |
+
prompts = [state.prompts[v + 1] for v in has_masks]
|
334 |
+
negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
|
335 |
+
mask_strengths = [state.mask_strengths[v] for v in has_masks]
|
336 |
+
mask_stds = [state.mask_stds[v] for v in has_masks]
|
337 |
+
prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
|
338 |
+
else:
|
339 |
+
masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
|
340 |
+
prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
|
341 |
+
negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
|
342 |
+
mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
|
343 |
+
mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
|
344 |
+
prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]
|
345 |
+
|
346 |
+
prompts, negative_prompts = preprocess_prompts(
|
347 |
+
prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)
|
348 |
+
|
349 |
+
return generate(
|
350 |
+
state,
|
351 |
+
prompts,
|
352 |
+
negative_prompts,
|
353 |
+
masks=masks,
|
354 |
+
mask_strengths=mask_strengths,
|
355 |
+
mask_stds=mask_stds,
|
356 |
+
prompt_strengths=prompt_strengths,
|
357 |
+
background=background.convert('RGB'),
|
358 |
+
background_prompt=state.prompts[0],
|
359 |
+
background_negative_prompt=state.neg_prompts[0],
|
360 |
+
height=opt.height,
|
361 |
+
width=opt.width,
|
362 |
+
bootstrap_steps=2,
|
363 |
+
guidance_scale=0,
|
364 |
+
)
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
### Load examples
|
369 |
+
|
370 |
+
|
371 |
+
root = pathlib.Path(__file__).parent
|
372 |
+
print(root)
|
373 |
+
example_root = os.path.join(root, 'examples')
|
374 |
+
example_images = glob.glob(os.path.join(example_root, '*.png'))
|
375 |
+
example_images = [Image.open(i) for i in example_images]
|
376 |
+
|
377 |
+
with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
|
378 |
+
prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']
|
379 |
+
|
380 |
+
with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
|
381 |
+
prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']
|
382 |
+
|
383 |
+
with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
|
384 |
+
prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']
|
385 |
+
|
386 |
+
with open(os.path.join(example_root, 'prompt_props.txt')) as f:
|
387 |
+
prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
|
388 |
+
prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}
|
389 |
+
|
390 |
+
prompt_background = lambda: random.choice(prompts_background)
|
391 |
+
prompt_girl = lambda: random.choice(prompts_girl)
|
392 |
+
prompt_boy = lambda: random.choice(prompts_boy)
|
393 |
+
prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()
|
394 |
+
|
395 |
+
|
396 |
+
### Main application
|
397 |
+
|
398 |
+
css = f"""
|
399 |
+
#run-button {{
|
400 |
+
font-size: 30pt;
|
401 |
+
background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
|
402 |
+
margin: 0;
|
403 |
+
padding: 15px 45px;
|
404 |
+
text-align: center;
|
405 |
+
text-transform: uppercase;
|
406 |
+
transition: 0.5s;
|
407 |
+
background-size: 200% auto;
|
408 |
+
color: white;
|
409 |
+
box-shadow: 0 0 20px #eee;
|
410 |
+
border-radius: 10px;
|
411 |
+
display: block;
|
412 |
+
background-position: right center;
|
413 |
+
}}
|
414 |
+
|
415 |
+
#run-button:hover {{
|
416 |
+
background-position: left center;
|
417 |
+
color: #fff;
|
418 |
+
text-decoration: none;
|
419 |
+
}}
|
420 |
+
|
421 |
+
#semantic-palette {{
|
422 |
+
border-style: solid;
|
423 |
+
border-width: 0.2em;
|
424 |
+
border-color: #eee;
|
425 |
+
}}
|
426 |
+
|
427 |
+
#semantic-palette:hover {{
|
428 |
+
box-shadow: 0 0 20px #eee;
|
429 |
+
}}
|
430 |
+
|
431 |
+
#output-screen {{
|
432 |
+
width: 100%;
|
433 |
+
aspect-ratio: {opt.width} / {opt.height};
|
434 |
+
}}
|
435 |
+
|
436 |
+
.layer-wrap {{
|
437 |
+
display: none;
|
438 |
+
}}
|
439 |
+
"""
|
440 |
+
|
441 |
+
for i in range(opt.max_palettes + 1):
|
442 |
+
css = css + f"""
|
443 |
+
.secondary#semantic-palette-{i} {{
|
444 |
+
background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
|
445 |
+
color: white;
|
446 |
+
}}
|
447 |
+
|
448 |
+
.primary#semantic-palette-{i} {{
|
449 |
+
background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
|
450 |
+
color: white;
|
451 |
+
}}
|
452 |
+
"""
|
453 |
+
|
454 |
+
|
455 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
456 |
+
|
457 |
+
iface = argparse.Namespace()
|
458 |
+
|
459 |
+
def _define_state():
|
460 |
+
state = argparse.Namespace()
|
461 |
+
|
462 |
+
# Cursor.
|
463 |
+
state.current_palette = 0 # 0: Background; 1,2,3,...: Layers
|
464 |
+
state.model_id = list(model_dict.keys())[0]
|
465 |
+
state.style_name = '(None)'
|
466 |
+
state.quality_name = 'Standard v3.1'
|
467 |
+
|
468 |
+
# State variables (one-hot).
|
469 |
+
state.active_palettes = 1
|
470 |
+
|
471 |
+
# Front-end initialized to the default values.
|
472 |
+
prompt_props_ = prompt_props()
|
473 |
+
state.prompt_names = [
|
474 |
+
'π Background',
|
475 |
+
'π§ Girl',
|
476 |
+
'π¦ Boy',
|
477 |
+
] + prompt_props_ + ['π¨ New Palette' for _ in range(opt.max_palettes - 5)]
|
478 |
+
state.prompts = [
|
479 |
+
prompt_background(),
|
480 |
+
prompt_girl(),
|
481 |
+
prompt_boy(),
|
482 |
+
] + [prompts_props[k] for k in prompt_props_] + ['' for _ in range(opt.max_palettes - 5)]
|
483 |
+
state.neg_prompts = [
|
484 |
+
opt.default_negative_prompt
|
485 |
+
+ (', humans, humans, humans' if i == 0 else '')
|
486 |
+
for i in range(opt.max_palettes + 1)
|
487 |
+
]
|
488 |
+
state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
|
489 |
+
state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
|
490 |
+
state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
|
491 |
+
state.seed = opt.seed
|
492 |
+
return state
|
493 |
+
|
494 |
+
state = gr.State(value=_define_state)
|
495 |
+
|
496 |
+
|
497 |
+
### Demo user interface
|
498 |
+
|
499 |
+
gr.HTML(
|
500 |
+
"""
|
501 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
502 |
+
<div>
|
503 |
+
<h1>π§ Semantic Paint X Animagine XL 3.1 π¨</h1>
|
504 |
+
<h5 style="margin: 0;">powered by</h5>
|
505 |
+
<h3>StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control</h3>
|
506 |
+
<h5 style="margin: 0;">and</h5>
|
507 |
+
<h3>Animagine XL 3.1 by Cagliostro Research Lab</h3>
|
508 |
+
<h5 style="margin: 0;">If you β€οΈ our project, please visit our Github and give us a π!</h5>
|
509 |
+
</br>
|
510 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
511 |
+
<a href='https://arxiv.org/abs/2403.09055'>
|
512 |
+
<img src="https://img.shields.io/badge/arXiv-2403.09055-red">
|
513 |
+
</a>
|
514 |
+
|
515 |
+
<a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
|
516 |
+
<img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
|
517 |
+
</a>
|
518 |
+
|
519 |
+
<a href='https://github.com/ironjr/StreamMultiDiffusion'>
|
520 |
+
<img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
|
521 |
+
</a>
|
522 |
+
|
523 |
+
<a href='https://twitter.com/_ironjr_'>
|
524 |
+
<img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
|
525 |
+
</a>
|
526 |
+
|
527 |
+
<a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
|
528 |
+
<img src='https://img.shields.io/badge/license-MIT-lightgrey'>
|
529 |
+
</a>
|
530 |
+
|
531 |
+
<a href='https://huggingface.co/papers/2403.09055'>
|
532 |
+
<img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper-yellow'>
|
533 |
+
</a>
|
534 |
+
|
535 |
+
<a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
|
536 |
+
<img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-v1.5-yellow'>
|
537 |
+
</a>
|
538 |
+
|
539 |
+
<a href='https://huggingface.co/cagliostrolab/animagine-xl-3.1'>
|
540 |
+
<img src='https://img.shields.io/badge/%F0%9F%A4%97%20Model-AnimagineXL3.1-yellow'>
|
541 |
+
</a>
|
542 |
+
</div>
|
543 |
+
</div>
|
544 |
+
</div>
|
545 |
+
<div>
|
546 |
+
</br>
|
547 |
+
</div>
|
548 |
+
"""
|
549 |
+
)
|
550 |
+
|
551 |
+
with gr.Row():
|
552 |
+
|
553 |
+
iface.image_slot = gr.Image(
|
554 |
+
interactive=False,
|
555 |
+
show_label=False,
|
556 |
+
show_download_button=True,
|
557 |
+
type='pil',
|
558 |
+
label='Generated Result',
|
559 |
+
elem_id='output-screen',
|
560 |
+
value=lambda: random.choice(example_images),
|
561 |
+
)
|
562 |
+
|
563 |
+
with gr.Row():
|
564 |
+
|
565 |
+
with gr.Column(scale=1):
|
566 |
+
|
567 |
+
with gr.Group(elem_id='semantic-palette'):
|
568 |
+
|
569 |
+
gr.HTML(
|
570 |
+
"""
|
571 |
+
<div style="justify-content: center; align-items: center;">
|
572 |
+
<br/>
|
573 |
+
<h3 style="margin: 0; text-align: center;"><b>π§ Semantic Palette π¨</b></h3>
|
574 |
+
<br/>
|
575 |
+
</div>
|
576 |
+
"""
|
577 |
+
)
|
578 |
+
|
579 |
+
iface.btn_semantics = [gr.Button(
|
580 |
+
value=state.value.prompt_names[0],
|
581 |
+
variant='primary',
|
582 |
+
elem_id='semantic-palette-0',
|
583 |
+
)]
|
584 |
+
for i in range(opt.max_palettes):
|
585 |
+
iface.btn_semantics.append(gr.Button(
|
586 |
+
value=state.value.prompt_names[i + 1],
|
587 |
+
variant='secondary',
|
588 |
+
visible=(i < state.value.active_palettes),
|
589 |
+
elem_id=f'semantic-palette-{i + 1}'
|
590 |
+
))
|
591 |
+
|
592 |
+
iface.btn_add_palette = gr.Button(
|
593 |
+
value='Create New Semantic Brush',
|
594 |
+
variant='primary',
|
595 |
+
)
|
596 |
+
|
597 |
+
with gr.Accordion(label='Import/Export Semantic Palette', open=False):
|
598 |
+
iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
|
599 |
+
iface.json_state_export = gr.JSON(label='Exported Palette')
|
600 |
+
iface.btn_export_state = gr.Button("Export Palette β‘οΈ JSON", variant='primary')
|
601 |
+
iface.btn_import_state = gr.Button("Import JSON β‘οΈ Palette", variant='secondary')
|
602 |
+
|
603 |
+
gr.HTML(
|
604 |
+
"""
|
605 |
+
<div>
|
606 |
+
</br>
|
607 |
+
</div>
|
608 |
+
<div style="justify-content: center; align-items: center;">
|
609 |
+
<h3 style="margin: 0; text-align: center;"><b>βUsageβ</b></h3>
|
610 |
+
</br>
|
611 |
+
<div style="justify-content: center; align-items: left; text-align: left;">
|
612 |
+
<p>1-1. Type in the background prompt. Background is not required if you paint the whole drawpad.</p>
|
613 |
+
<p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image will make the app into inpainting mode. Removing the image returns to the creation mode. In the inpainting mode, increasing the <em>Mask Blur STD</em> > 8 for every colored palette is recommended for smooth boundaries.</p>
|
614 |
+
<p>2. Select a semantic brush by clicking onto one in the <b>Semantic Palette</b> above. Edit prompt for the semantic brush.</p>
|
615 |
+
<p>2-1. If you are willing to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
|
616 |
+
<p>3. Start drawing in the <b>Semantic Drawpad</b> tab. The brush color is directly linked to the semantic brushes.</p>
|
617 |
+
<p>4. Click [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
|
618 |
+
</div>
|
619 |
+
</div>
|
620 |
+
"""
|
621 |
+
)
|
622 |
+
|
623 |
+
gr.HTML(
|
624 |
+
"""
|
625 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
626 |
+
<h5 style="margin: 0;"><b>... or run in your own π€ space!</b></h5>
|
627 |
+
</div>
|
628 |
+
"""
|
629 |
+
)
|
630 |
+
|
631 |
+
gr.DuplicateButton()
|
632 |
+
|
633 |
+
with gr.Column(scale=4):
|
634 |
+
|
635 |
+
with gr.Row():
|
636 |
+
|
637 |
+
with gr.Column(scale=3):
|
638 |
+
|
639 |
+
iface.ctrl_semantic = gr.ImageEditor(
|
640 |
+
image_mode='RGBA',
|
641 |
+
sources=['upload', 'clipboard', 'webcam'],
|
642 |
+
transforms=['crop'],
|
643 |
+
crop_size=(opt.width, opt.height),
|
644 |
+
brush=gr.Brush(
|
645 |
+
colors=opt.colors[1:],
|
646 |
+
color_mode="fixed",
|
647 |
+
),
|
648 |
+
type='pil',
|
649 |
+
label='Semantic Drawpad',
|
650 |
+
elem_id='drawpad',
|
651 |
+
)
|
652 |
+
|
653 |
+
with gr.Column(scale=1):
|
654 |
+
|
655 |
+
iface.btn_generate = gr.Button(
|
656 |
+
value='Generate!',
|
657 |
+
variant='primary',
|
658 |
+
# scale=1,
|
659 |
+
elem_id='run-button'
|
660 |
+
)
|
661 |
+
with gr.Group(elem_id="share-btn-container"):
|
662 |
+
gr.HTML(community_icon_html)
|
663 |
+
gr.HTML(loading_icon_html)
|
664 |
+
iface.btn_share = gr.Button("Share with Community", elem_id="share-btn")
|
665 |
+
|
666 |
+
iface.model_select = gr.Radio(
|
667 |
+
list(model_dict.keys()),
|
668 |
+
label='Stable Diffusion Checkpoint',
|
669 |
+
info='Choose your favorite style.',
|
670 |
+
value=state.value.model_id,
|
671 |
+
)
|
672 |
+
|
673 |
+
with gr.Accordion(label='Prompt Engineering', open=True):
|
674 |
+
iface.quality_select = gr.Dropdown(
|
675 |
+
label='Quality Presets',
|
676 |
+
interactive=True,
|
677 |
+
choices=list(_quality_dict.keys()),
|
678 |
+
value='Standard v3.1',
|
679 |
+
)
|
680 |
+
iface.style_select = gr.Radio(
|
681 |
+
label='Style Preset',
|
682 |
+
container=True,
|
683 |
+
interactive=True,
|
684 |
+
choices=list(_style_dict.keys()),
|
685 |
+
value='(None)',
|
686 |
+
)
|
687 |
+
|
688 |
+
with gr.Group(elem_id='control-panel'):
|
689 |
+
|
690 |
+
with gr.Row():
|
691 |
+
iface.tbox_prompt = gr.Textbox(
|
692 |
+
label='Edit Prompt for Background',
|
693 |
+
info='What do you want to draw?',
|
694 |
+
value=state.value.prompts[0],
|
695 |
+
placeholder=lambda: random.choice(prompt_suggestions),
|
696 |
+
scale=2,
|
697 |
+
)
|
698 |
+
|
699 |
+
iface.tbox_name = gr.Textbox(
|
700 |
+
label='Edit Brush Name',
|
701 |
+
info='Just for your convenience.',
|
702 |
+
value=state.value.prompt_names[0],
|
703 |
+
placeholder='π Background',
|
704 |
+
scale=1,
|
705 |
+
)
|
706 |
+
|
707 |
+
with gr.Row():
|
708 |
+
iface.tbox_neg_prompt = gr.Textbox(
|
709 |
+
label='Edit Negative Prompt for Background',
|
710 |
+
info='Add unwanted objects for this semantic brush.',
|
711 |
+
value=opt.default_negative_prompt,
|
712 |
+
scale=2,
|
713 |
+
)
|
714 |
+
|
715 |
+
iface.slider_strength = gr.Slider(
|
716 |
+
label='Prompt Strength',
|
717 |
+
info='Blends fg & bg in the prompt level, >0.8 Preferred.',
|
718 |
+
minimum=0.5,
|
719 |
+
maximum=1.0,
|
720 |
+
value=opt.default_prompt_strength,
|
721 |
+
scale=1,
|
722 |
+
)
|
723 |
+
|
724 |
+
with gr.Row():
|
725 |
+
iface.slider_alpha = gr.Slider(
|
726 |
+
label='Mask Alpha',
|
727 |
+
info='Factor multiplied to the mask before quantization. Extremely sensitive, >0.98 Preferred.',
|
728 |
+
minimum=0.5,
|
729 |
+
maximum=1.0,
|
730 |
+
value=opt.default_mask_strength,
|
731 |
+
)
|
732 |
+
|
733 |
+
iface.slider_std = gr.Slider(
|
734 |
+
label='Mask Blur STD',
|
735 |
+
info='Blends fg & bg in the latent level, 0 for generation, 8-32 for inpainting.',
|
736 |
+
minimum=0.0001,
|
737 |
+
maximum=100.0,
|
738 |
+
value=opt.default_mask_std,
|
739 |
+
)
|
740 |
+
|
741 |
+
iface.slider_seed = gr.Slider(
|
742 |
+
label='Seed',
|
743 |
+
info='The global seed.',
|
744 |
+
minimum=-1,
|
745 |
+
maximum=2147483647,
|
746 |
+
step=1,
|
747 |
+
value=opt.seed,
|
748 |
+
)
|
749 |
+
|
750 |
+
### Attach event handlers
|
751 |
+
|
752 |
+
for idx, btn in enumerate(iface.btn_semantics):
|
753 |
+
btn.click(
|
754 |
+
fn=partial(select_palette, idx=idx),
|
755 |
+
inputs=[state, btn],
|
756 |
+
outputs=[state] + iface.btn_semantics + [
|
757 |
+
iface.tbox_name,
|
758 |
+
iface.tbox_prompt,
|
759 |
+
iface.tbox_neg_prompt,
|
760 |
+
iface.slider_alpha,
|
761 |
+
iface.slider_strength,
|
762 |
+
iface.slider_std,
|
763 |
+
],
|
764 |
+
api_name=f'select_palette_{idx}',
|
765 |
+
)
|
766 |
+
|
767 |
+
iface.btn_add_palette.click(
|
768 |
+
fn=add_palette,
|
769 |
+
inputs=state,
|
770 |
+
outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
|
771 |
+
api_name='create_new',
|
772 |
+
)
|
773 |
+
|
774 |
+
iface.btn_generate.click(
|
775 |
+
fn=run,
|
776 |
+
inputs=[state, iface.ctrl_semantic],
|
777 |
+
outputs=iface.image_slot,
|
778 |
+
api_name='run',
|
779 |
+
)
|
780 |
+
|
781 |
+
iface.slider_alpha.input(
|
782 |
+
fn=change_mask_strength,
|
783 |
+
inputs=[state, iface.slider_alpha],
|
784 |
+
outputs=state,
|
785 |
+
api_name='change_alpha',
|
786 |
+
)
|
787 |
+
iface.slider_std.input(
|
788 |
+
fn=change_std,
|
789 |
+
inputs=[state, iface.slider_std],
|
790 |
+
outputs=state,
|
791 |
+
api_name='change_std',
|
792 |
+
)
|
793 |
+
iface.slider_strength.input(
|
794 |
+
fn=change_prompt_strength,
|
795 |
+
inputs=[state, iface.slider_strength],
|
796 |
+
outputs=state,
|
797 |
+
api_name='change_strength',
|
798 |
+
)
|
799 |
+
iface.slider_seed.input(
|
800 |
+
fn=reset_seed,
|
801 |
+
inputs=[state, iface.slider_seed],
|
802 |
+
outputs=state,
|
803 |
+
api_name='reset_seed',
|
804 |
+
)
|
805 |
+
|
806 |
+
iface.tbox_name.input(
|
807 |
+
fn=rename_prompt,
|
808 |
+
inputs=[state, iface.tbox_name],
|
809 |
+
outputs=[state] + iface.btn_semantics,
|
810 |
+
api_name='prompt_rename',
|
811 |
+
)
|
812 |
+
iface.tbox_prompt.input(
|
813 |
+
fn=change_prompt,
|
814 |
+
inputs=[state, iface.tbox_prompt],
|
815 |
+
outputs=state,
|
816 |
+
api_name='prompt_edit',
|
817 |
+
)
|
818 |
+
iface.tbox_neg_prompt.input(
|
819 |
+
fn=change_neg_prompt,
|
820 |
+
inputs=[state, iface.tbox_neg_prompt],
|
821 |
+
outputs=state,
|
822 |
+
api_name='neg_prompt_edit',
|
823 |
+
)
|
824 |
+
|
825 |
+
iface.model_select.change(
|
826 |
+
fn=select_model,
|
827 |
+
inputs=[state, iface.model_select],
|
828 |
+
outputs=state,
|
829 |
+
api_name='model_select',
|
830 |
+
)
|
831 |
+
iface.style_select.change(
|
832 |
+
fn=select_style,
|
833 |
+
inputs=[state, iface.style_select],
|
834 |
+
outputs=state,
|
835 |
+
api_name='style_select',
|
836 |
+
)
|
837 |
+
iface.quality_select.change(
|
838 |
+
fn=select_quality,
|
839 |
+
inputs=[state, iface.quality_select],
|
840 |
+
outputs=state,
|
841 |
+
api_name='quality_select',
|
842 |
+
)
|
843 |
+
|
844 |
+
iface.btn_share.click(None, [], [], _js=share_js)
|
845 |
+
|
846 |
+
iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
|
847 |
+
iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
|
848 |
+
state,
|
849 |
+
*iface.btn_semantics,
|
850 |
+
iface.model_select,
|
851 |
+
iface.style_select,
|
852 |
+
iface.quality_select,
|
853 |
+
iface.tbox_prompt,
|
854 |
+
iface.tbox_name,
|
855 |
+
iface.tbox_neg_prompt,
|
856 |
+
iface.slider_strength,
|
857 |
+
iface.slider_alpha,
|
858 |
+
iface.slider_std,
|
859 |
+
iface.slider_seed,
|
860 |
+
])
|
861 |
+
|
862 |
+
gr.HTML(
|
863 |
+
"""
|
864 |
+
<div class="footer">
|
865 |
+
<p>We thank <a href="https://cagliostrolab.net/">Cagliostro Research Lab</a> for their permission to use <a href="https://huggingface.co/cagliostrolab/animagine-xl-3.1">Animagine XL 3.1</a> model under academic purpose.
|
866 |
+
Note that the MIT license only applies to StreamMultiDiffusion and Semantic Palette demo app, but not Animagine XL 3.1 model, which is distributed under <a href="https://freedevproject.org/faipl-1.0-sd/">Fair AI Public License 1.0-SD</a>.
|
867 |
+
</p>
|
868 |
+
</div>
|
869 |
+
"""
|
870 |
+
)
|
871 |
+
|
872 |
+
if __name__ == '__main__':
|
873 |
+
demo.launch(server_port=opt.port)
|
examples/prompt_background.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Maximalism, best quality, high quality, no humans, background, clear sky, γ
black sky, starry universe, planets
|
2 |
+
Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
|
3 |
+
Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
|
4 |
+
Maximalism, best quality, high quality, no humans, background, galaxy
|
5 |
+
Maximalism, best quality, high quality, no humans, background, sky, daylight
|
6 |
+
Maximalism, best quality, high quality, no humans, background, skyscrappers, rooftop, city of light, helicopters, bright night, sky
|
7 |
+
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
|
8 |
+
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
|
examples/prompt_background_advanced.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/prompt_boy.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1boy, looking at viewer, brown hair, blue shirt
|
2 |
+
1boy, looking at viewer, brown hair, red shirt
|
3 |
+
1boy, looking at viewer, brown hair, purple shirt
|
4 |
+
1boy, looking at viewer, brown hair, orange shirt
|
5 |
+
1boy, looking at viewer, brown hair, yellow shirt
|
6 |
+
1boy, looking at viewer, brown hair, green shirt
|
7 |
+
1boy, looking back, side shaved hair, cyberpunk cloths, robotic suit, large body
|
8 |
+
1boy, looking back, short hair, renaissance cloths, noble boy
|
9 |
+
1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
|
10 |
+
1boy, looking at viewer, a king, kingly grace, majestic cloths, crown
|
11 |
+
1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
|
12 |
+
1boy, looking at viewer, a medieval knight, helmet, swordman, plate armour
|
13 |
+
1boy, looking at viewer, black haired, old eastern cloth
|
14 |
+
1boy, looking back, messy hair, suit, short beard, noir
|
15 |
+
1boy, looking at viewer, cute face, light smile, starry eyes, jeans
|
examples/prompt_girl.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese cloths
|
2 |
+
1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
|
3 |
+
1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
|
4 |
+
1girl, looking at viewer, fantasy adventurer, backpack
|
5 |
+
1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
|
6 |
+
1girl, looking at viewer, soldier, rusty cloths, backpack, pretty face, sad smile, tears
|
7 |
+
1girl, looking at viewer, majestic cloths, long hair, glittering eye, pretty face
|
8 |
+
1girl, looking at viewer, from behind, majestic cloths, long hair, glittering eye
|
9 |
+
1girl, looking at viewer, evil smile, very short hair, suit, evil genius
|
10 |
+
1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
|
11 |
+
1girl, looking at viewer, purple hair, happy face, black leather jacket
|
12 |
+
1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
|
13 |
+
1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
|
14 |
+
1girl, looking at viewer, pretty face, light smile, orange hair, casual cloths
|
15 |
+
1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
|
16 |
+
1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
|
examples/prompt_props.txt
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
π― Palace, Gyeongbokgung palace
|
2 |
+
π³ Garden, Chinese garden
|
3 |
+
ποΈ Rome, Ancient city of Rome
|
4 |
+
𧱠Wall, Castle wall
|
5 |
+
π΄ Mars, Martian desert, Red rocky desert
|
6 |
+
π» Grassland, Grasslands
|
7 |
+
π‘ Village, A fantasy village
|
8 |
+
π Dragon, a flying chinese dragon
|
9 |
+
π Earth, Earth seen from ISS
|
10 |
+
π Space Station, the international space station
|
11 |
+
πͺ» Grassland, Rusty grassland with flowers
|
12 |
+
πΌοΈ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
|
13 |
+
ποΈ City Ruin, city, ruins, ruins, ruins, deserted
|
14 |
+
ποΈ Renaissance City, renaissance city, renaissance city, renaissance city
|
15 |
+
π· Flowers, Flower garden
|
16 |
+
πΌ Flowers, Flower garden, spring garden
|
17 |
+
πΉ Flowers, Flowers flowers, flowers
|
18 |
+
β°οΈ Dolomites Mountains, Dolomites
|
19 |
+
β°οΈ Himalayas Mountains, Himalayas
|
20 |
+
β°οΈ Alps Mountains, Alps
|
21 |
+
β°οΈ Mountains, Mountains
|
22 |
+
βοΈβ°οΈ Mountains, Winter mountains
|
23 |
+
π·β°οΈ Mountains, Spring mountains
|
24 |
+
πβ°οΈ Mountains, Summer mountains
|
25 |
+
π΅ Desert, A sandy desert, dunes
|
26 |
+
πͺ¨π΅ Desert, A rocky desert
|
27 |
+
π¦ Waterfall, A giant waterfall
|
28 |
+
π Ocean, Ocean
|
29 |
+
β±οΈ Seashore, Seashore
|
30 |
+
π
Sea Horizon, Sea horizon
|
31 |
+
π Lake, Clear blue lake
|
32 |
+
π» Computer, A giant supecomputer
|
33 |
+
π³ Tree, A giant tree
|
34 |
+
π³ Forest, A forest
|
35 |
+
π³π³ Forest, A dense forest
|
36 |
+
π² Forest, Winter forest
|
37 |
+
π΄ Forest, Summer forest, tropical forest
|
38 |
+
π Hat, A hat
|
39 |
+
πΆ Dog, Doggy body parts
|
40 |
+
π» Cat, A cat
|
41 |
+
π¦ Owl, A small sitting owl
|
42 |
+
π¦
Eagle, A small sitting eagle
|
43 |
+
π Rocket, A flying rocket
|
model.py
ADDED
@@ -0,0 +1,1410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
22 |
+
from diffusers import (
|
23 |
+
AutoencoderTiny,
|
24 |
+
StableDiffusionXLPipeline,
|
25 |
+
UNet2DConditionModel,
|
26 |
+
EulerDiscreteScheduler,
|
27 |
+
)
|
28 |
+
from diffusers.models.attention_processor import (
|
29 |
+
AttnProcessor2_0,
|
30 |
+
FusedAttnProcessor2_0,
|
31 |
+
LoRAAttnProcessor2_0,
|
32 |
+
LoRAXFormersAttnProcessor,
|
33 |
+
XFormersAttnProcessor,
|
34 |
+
)
|
35 |
+
from diffusers.loaders import (
|
36 |
+
StableDiffusionXLLoraLoaderMixin,
|
37 |
+
TextualInversionLoaderMixin,
|
38 |
+
)
|
39 |
+
from diffusers.utils import (
|
40 |
+
USE_PEFT_BACKEND,
|
41 |
+
logging,
|
42 |
+
)
|
43 |
+
from huggingface_hub import hf_hub_download
|
44 |
+
from safetensors.torch import load_file
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
47 |
+
|
48 |
+
import torch
|
49 |
+
import torch.nn as nn
|
50 |
+
import torch.nn.functional as F
|
51 |
+
import torchvision.transforms as T
|
52 |
+
from einops import rearrange
|
53 |
+
|
54 |
+
from typing import Tuple, List, Literal, Optional, Union
|
55 |
+
from tqdm import tqdm
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
|
59 |
+
|
60 |
+
|
61 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
62 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
63 |
+
"""
|
64 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
65 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
66 |
+
"""
|
67 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
68 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
69 |
+
# rescale the results from guidance (fixes overexposure)
|
70 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
71 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
72 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
73 |
+
return noise_cfg
|
74 |
+
|
75 |
+
|
76 |
+
class StableMultiDiffusionSDXLPipeline(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
device: torch.device,
|
80 |
+
dtype: torch.dtype = torch.float16,
|
81 |
+
hf_key: Optional[str] = None,
|
82 |
+
lora_key: Optional[str] = None,
|
83 |
+
load_from_local: bool = False, # Turn on if you have already downloaed LoRA & Hugging Face hub is down.
|
84 |
+
default_mask_std: float = 1.0, # 8.0
|
85 |
+
default_mask_strength: float = 1.0,
|
86 |
+
default_prompt_strength: float = 1.0, # 8.0
|
87 |
+
default_bootstrap_steps: int = 1,
|
88 |
+
default_boostrap_mix_steps: float = 1.0,
|
89 |
+
default_bootstrap_leak_sensitivity: float = 0.2,
|
90 |
+
default_preprocess_mask_cover_alpha: float = 0.3,
|
91 |
+
t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # # [0, 12, 25, 37], # Magic number.
|
92 |
+
mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
|
93 |
+
has_i2t: bool = True,
|
94 |
+
lora_weight: float = 1.0,
|
95 |
+
) -> None:
|
96 |
+
r"""Stabilized MultiDiffusion for fast sampling.
|
97 |
+
|
98 |
+
Accelrated region-based text-to-image synthesis with Latent Consistency
|
99 |
+
Model while preserving mask fidelity and quality.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
device (torch.device): Specify CUDA device.
|
103 |
+
hf_key (Optional[str]): Custom StableDiffusion checkpoint for
|
104 |
+
stylized generation.
|
105 |
+
lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
|
106 |
+
load_from_local (bool): Turn on if you have already downloaed LoRA
|
107 |
+
& Hugging Face hub is down.
|
108 |
+
default_mask_std (float): Preprocess mask with Gaussian blur with
|
109 |
+
specified standard deviation.
|
110 |
+
default_mask_strength (float): Preprocess mask by multiplying it
|
111 |
+
globally with the specified variable. Caution: extremely
|
112 |
+
sensitive. Recommended range: 0.98-1.
|
113 |
+
default_prompt_strength (float): Preprocess foreground prompts
|
114 |
+
globally by linearly interpolating its embedding with the
|
115 |
+
background prompt embeddint with specified mix ratio. Useful
|
116 |
+
control handle for foreground blending. Recommended range:
|
117 |
+
0.5-1.
|
118 |
+
default_bootstrap_steps (int): Bootstrapping stage steps to
|
119 |
+
encourage region separation. Recommended range: 1-3.
|
120 |
+
default_boostrap_mix_steps (float): Bootstrapping background is a
|
121 |
+
linear interpolation between background latent and the white
|
122 |
+
image latent. This handle controls the mix ratio. Available
|
123 |
+
range: 0-(number of bootstrapping inference steps). For
|
124 |
+
example, 2.3 means that for the first two steps, white image
|
125 |
+
is used as a bootstrapping background and in the third step,
|
126 |
+
mixture of white (0.3) and registered background (0.7) is used
|
127 |
+
as a bootstrapping background.
|
128 |
+
default_bootstrap_leak_sensitivity (float): Postprocessing at each
|
129 |
+
inference step by masking away the remaining bootstrap
|
130 |
+
backgrounds t Recommended range: 0-1.
|
131 |
+
default_preprocess_mask_cover_alpha (float): Optional preprocessing
|
132 |
+
where each mask covered by other masks is reduced in its alpha
|
133 |
+
value by this specified factor.
|
134 |
+
t_index_list (List[int]): The default scheduling for LCM scheduler.
|
135 |
+
mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
|
136 |
+
defines the mask quantization modes. Details in the codes of
|
137 |
+
`self.process_mask`. Basically, this (subtly) controls the
|
138 |
+
smoothness of foreground-background blending. More continuous
|
139 |
+
means more blending, but smaller generated patch depending on
|
140 |
+
the mask standard deviation.
|
141 |
+
has_i2t (bool): Automatic background image to text prompt con-
|
142 |
+
version with BLIP-2 model. May not be necessary for the non-
|
143 |
+
streaming application.
|
144 |
+
lora_weight (float): Adjusts weight of the LCM/Lightning LoRA.
|
145 |
+
Heavily affects the overall quality!
|
146 |
+
"""
|
147 |
+
super().__init__()
|
148 |
+
|
149 |
+
self.device = device
|
150 |
+
self.dtype = dtype
|
151 |
+
|
152 |
+
self.default_mask_std = default_mask_std
|
153 |
+
self.default_mask_strength = default_mask_strength
|
154 |
+
self.default_prompt_strength = default_prompt_strength
|
155 |
+
self.default_t_list = t_index_list
|
156 |
+
self.default_bootstrap_steps = default_bootstrap_steps
|
157 |
+
self.default_boostrap_mix_steps = default_boostrap_mix_steps
|
158 |
+
self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
|
159 |
+
self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
|
160 |
+
self.mask_type = mask_type
|
161 |
+
|
162 |
+
# Create model.
|
163 |
+
print(f'[INFO] Loading Stable Diffusion...')
|
164 |
+
variant = None
|
165 |
+
model_ckpt = None
|
166 |
+
lora_ckpt = None
|
167 |
+
lightning_repo = 'ByteDance/SDXL-Lightning'
|
168 |
+
if hf_key is not None:
|
169 |
+
print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
|
170 |
+
model_key = hf_key
|
171 |
+
lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
|
172 |
+
|
173 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, variant=variant, torch_dtype=self.dtype).to(self.device)
|
174 |
+
self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
|
175 |
+
self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
|
176 |
+
self.pipe.fuse_lora()
|
177 |
+
else:
|
178 |
+
model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
|
179 |
+
variant = 'fp16'
|
180 |
+
model_ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
|
181 |
+
|
182 |
+
unet = UNet2DConditionModel.from_config(model_key, subfolder='unet').to(self.device, self.dtype)
|
183 |
+
unet.load_state_dict(load_file(hf_hub_download(lightning_repo, model_ckpt), device=self.device))
|
184 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, unet=unet, torch_dtype=self.dtype, variant=variant).to(self.device)
|
185 |
+
|
186 |
+
# Create model
|
187 |
+
if has_i2t:
|
188 |
+
self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
|
189 |
+
self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
|
190 |
+
|
191 |
+
# Use SDXL-Lightning LoRA by default.
|
192 |
+
self.pipe.scheduler = EulerDiscreteScheduler.from_config(
|
193 |
+
self.pipe.scheduler.config, timestep_spacing="trailing")
|
194 |
+
self.scheduler = self.pipe.scheduler
|
195 |
+
self.default_num_inference_steps = 4
|
196 |
+
self.default_guidance_scale = 0.0
|
197 |
+
|
198 |
+
if t_index_list is None:
|
199 |
+
self.prepare_lightning_schedule(
|
200 |
+
list(range(self.default_num_inference_steps)),
|
201 |
+
self.default_num_inference_steps,
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
self.prepare_lightning_schedule(t_index_list, 50)
|
205 |
+
|
206 |
+
self.vae = self.pipe.vae
|
207 |
+
self.tokenizer = self.pipe.tokenizer
|
208 |
+
self.tokenizer_2 = self.pipe.tokenizer_2
|
209 |
+
self.text_encoder = self.pipe.text_encoder
|
210 |
+
self.text_encoder_2 = self.pipe.text_encoder_2
|
211 |
+
self.unet = self.pipe.unet
|
212 |
+
self.vae_scale_factor = self.pipe.vae_scale_factor
|
213 |
+
|
214 |
+
# Prepare white background for bootstrapping.
|
215 |
+
self.get_white_background(1024, 1024)
|
216 |
+
|
217 |
+
print(f'[INFO] Model is loaded!')
|
218 |
+
|
219 |
+
def prepare_lightning_schedule(
|
220 |
+
self,
|
221 |
+
t_index_list: Optional[List[int]] = None,
|
222 |
+
num_inference_steps: Optional[int] = None,
|
223 |
+
s_churn: float = 0.0,
|
224 |
+
s_tmin: float = 0.0,
|
225 |
+
s_tmax: float = float("inf"),
|
226 |
+
) -> None:
|
227 |
+
r"""Set up different inference schedule for the diffusion model.
|
228 |
+
|
229 |
+
You do not have to run this explicitly if you want to use the default
|
230 |
+
setting, but if you want other time schedules, run this function
|
231 |
+
between the module initialization and the main call.
|
232 |
+
|
233 |
+
Note:
|
234 |
+
- Recommended t_index_lists for LCMs:
|
235 |
+
- [0, 12, 25, 37]: Default schedule for 4 steps. Best for
|
236 |
+
panorama. Not recommended if you want to use bootstrapping.
|
237 |
+
Because bootstrapping stage affects the initial structuring
|
238 |
+
of the generated image & in this four step LCM, this is done
|
239 |
+
with only at the first step, the structure may be distorted.
|
240 |
+
- [0, 4, 12, 25, 37]: Recommended if you would use 1-step boot-
|
241 |
+
strapping. Default initialization in this implementation.
|
242 |
+
- [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
|
243 |
+
bootstrapping.
|
244 |
+
- Due to the characteristic of SD1.5 LCM LoRA, setting
|
245 |
+
`num_inference_steps` larger than 20 may results in overly blurry
|
246 |
+
and unrealistic images. Beware!
|
247 |
+
|
248 |
+
Args:
|
249 |
+
t_index_list (Optional[List[int]]): The specified scheduling step
|
250 |
+
regarding the maximum timestep as `num_inference_steps`, which
|
251 |
+
is by default, 50. That means that
|
252 |
+
`t_index_list=[0, 12, 25, 37]` is a relative time indices basd
|
253 |
+
on the full scale of 50. If None, reinitialize the module with
|
254 |
+
the default value.
|
255 |
+
num_inference_steps (Optional[int]): The maximum timestep of the
|
256 |
+
sampler. Defines relative scale of the `t_index_list`. Rarely
|
257 |
+
used in practice. If None, reinitialize the module with the
|
258 |
+
default value.
|
259 |
+
"""
|
260 |
+
if t_index_list is None:
|
261 |
+
t_index_list = self.default_t_list
|
262 |
+
if num_inference_steps is None:
|
263 |
+
num_inference_steps = self.default_num_inference_steps
|
264 |
+
|
265 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
266 |
+
self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)]
|
267 |
+
|
268 |
+
# EulerDiscreteScheduler
|
269 |
+
|
270 |
+
self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)]
|
271 |
+
self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:]
|
272 |
+
sigma_mask = torch.logical_and(s_tmin <= self.sigmas, self.sigmas <= s_tmax)
|
273 |
+
# self.gammas = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) * sigma_mask
|
274 |
+
self.gammas = min(s_churn / (num_inference_steps - 1), 2**0.5 - 1) * sigma_mask
|
275 |
+
self.sigma_hats = self.sigmas * (self.gammas + 1)
|
276 |
+
self.dt = self.sigmas_next - self.sigma_hats
|
277 |
+
|
278 |
+
noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
|
279 |
+
self.noise_lvs = noise_lvs[None, :, None, None, None]
|
280 |
+
self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
|
281 |
+
|
282 |
+
def upcast_vae(self):
|
283 |
+
dtype = self.vae.dtype
|
284 |
+
self.vae.to(dtype=torch.float32)
|
285 |
+
use_torch_2_0_or_xformers = isinstance(
|
286 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
287 |
+
(
|
288 |
+
AttnProcessor2_0,
|
289 |
+
XFormersAttnProcessor,
|
290 |
+
LoRAXFormersAttnProcessor,
|
291 |
+
LoRAAttnProcessor2_0,
|
292 |
+
FusedAttnProcessor2_0,
|
293 |
+
),
|
294 |
+
)
|
295 |
+
# if xformers or torch_2_0 is used attention block does not need
|
296 |
+
# to be in float32 which can save lots of memory
|
297 |
+
if use_torch_2_0_or_xformers:
|
298 |
+
self.vae.post_quant_conv.to(dtype)
|
299 |
+
self.vae.decoder.conv_in.to(dtype)
|
300 |
+
self.vae.decoder.mid_block.to(dtype)
|
301 |
+
|
302 |
+
def _get_add_time_ids(
|
303 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
304 |
+
):
|
305 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
306 |
+
|
307 |
+
passed_add_embed_dim = (
|
308 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
309 |
+
)
|
310 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
311 |
+
|
312 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
313 |
+
raise ValueError(
|
314 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
315 |
+
)
|
316 |
+
|
317 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
318 |
+
return add_time_ids
|
319 |
+
|
320 |
+
def encode_prompt(
|
321 |
+
self,
|
322 |
+
prompt: str,
|
323 |
+
prompt_2: Optional[str] = None,
|
324 |
+
device: Optional[torch.device] = None,
|
325 |
+
num_images_per_prompt: int = 1,
|
326 |
+
do_classifier_free_guidance: bool = True,
|
327 |
+
negative_prompt: Optional[str] = None,
|
328 |
+
negative_prompt_2: Optional[str] = None,
|
329 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
330 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
331 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
332 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
333 |
+
lora_scale: Optional[float] = None,
|
334 |
+
clip_skip: Optional[int] = None,
|
335 |
+
):
|
336 |
+
r"""
|
337 |
+
Encodes the prompt into text encoder hidden states.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
prompt (`str` or `List[str]`, *optional*):
|
341 |
+
prompt to be encoded
|
342 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
343 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
344 |
+
used in both text-encoders
|
345 |
+
device: (`torch.device`):
|
346 |
+
torch device
|
347 |
+
num_images_per_prompt (`int`):
|
348 |
+
number of images that should be generated per prompt
|
349 |
+
do_classifier_free_guidance (`bool`):
|
350 |
+
whether to use classifier free guidance or not
|
351 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
352 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
353 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
354 |
+
less than `1`).
|
355 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
356 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
357 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
358 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
359 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
360 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
361 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
362 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
363 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
364 |
+
argument.
|
365 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
366 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
367 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
368 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
369 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
370 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
371 |
+
input argument.
|
372 |
+
lora_scale (`float`, *optional*):
|
373 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
374 |
+
clip_skip (`int`, *optional*):
|
375 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
376 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
377 |
+
"""
|
378 |
+
device = device or self._execution_device
|
379 |
+
|
380 |
+
# set lora scale so that monkey patched LoRA
|
381 |
+
# function of text encoder can correctly access it
|
382 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
383 |
+
self._lora_scale = lora_scale
|
384 |
+
|
385 |
+
# dynamically adjust the LoRA scale
|
386 |
+
if self.text_encoder is not None:
|
387 |
+
if not USE_PEFT_BACKEND:
|
388 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
389 |
+
else:
|
390 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
391 |
+
|
392 |
+
if self.text_encoder_2 is not None:
|
393 |
+
if not USE_PEFT_BACKEND:
|
394 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
395 |
+
else:
|
396 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
397 |
+
|
398 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
399 |
+
|
400 |
+
if prompt is not None:
|
401 |
+
batch_size = len(prompt)
|
402 |
+
else:
|
403 |
+
batch_size = prompt_embeds.shape[0]
|
404 |
+
|
405 |
+
# Define tokenizers and text encoders
|
406 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
407 |
+
text_encoders = (
|
408 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
409 |
+
)
|
410 |
+
|
411 |
+
if prompt_embeds is None:
|
412 |
+
prompt_2 = prompt_2 or prompt
|
413 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
414 |
+
|
415 |
+
# textual inversion: process multi-vector tokens if necessary
|
416 |
+
prompt_embeds_list = []
|
417 |
+
prompts = [prompt, prompt_2]
|
418 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
419 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
420 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
421 |
+
|
422 |
+
text_inputs = tokenizer(
|
423 |
+
prompt,
|
424 |
+
padding="max_length",
|
425 |
+
max_length=tokenizer.model_max_length,
|
426 |
+
truncation=True,
|
427 |
+
return_tensors="pt",
|
428 |
+
)
|
429 |
+
|
430 |
+
text_input_ids = text_inputs.input_ids
|
431 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
432 |
+
|
433 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
434 |
+
text_input_ids, untruncated_ids
|
435 |
+
):
|
436 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
437 |
+
logger.warning(
|
438 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
439 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
440 |
+
)
|
441 |
+
|
442 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
443 |
+
|
444 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
445 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
446 |
+
if clip_skip is None:
|
447 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
448 |
+
else:
|
449 |
+
# "2" because SDXL always indexes from the penultimate layer.
|
450 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
451 |
+
|
452 |
+
prompt_embeds_list.append(prompt_embeds)
|
453 |
+
|
454 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
455 |
+
|
456 |
+
# get unconditional embeddings for classifier free guidance
|
457 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
458 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
459 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
460 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
461 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
462 |
+
negative_prompt = negative_prompt or ""
|
463 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
464 |
+
|
465 |
+
# normalize str to list
|
466 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
467 |
+
negative_prompt_2 = (
|
468 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
469 |
+
)
|
470 |
+
|
471 |
+
uncond_tokens: List[str]
|
472 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
473 |
+
raise TypeError(
|
474 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
475 |
+
f" {type(prompt)}."
|
476 |
+
)
|
477 |
+
elif batch_size != len(negative_prompt):
|
478 |
+
raise ValueError(
|
479 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
480 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
481 |
+
" the batch size of `prompt`."
|
482 |
+
)
|
483 |
+
else:
|
484 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
485 |
+
|
486 |
+
negative_prompt_embeds_list = []
|
487 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
488 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
489 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
490 |
+
|
491 |
+
max_length = prompt_embeds.shape[1]
|
492 |
+
uncond_input = tokenizer(
|
493 |
+
negative_prompt,
|
494 |
+
padding="max_length",
|
495 |
+
max_length=max_length,
|
496 |
+
truncation=True,
|
497 |
+
return_tensors="pt",
|
498 |
+
)
|
499 |
+
|
500 |
+
negative_prompt_embeds = text_encoder(
|
501 |
+
uncond_input.input_ids.to(device),
|
502 |
+
output_hidden_states=True,
|
503 |
+
)
|
504 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
505 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
506 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
507 |
+
|
508 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
509 |
+
|
510 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
511 |
+
|
512 |
+
if self.text_encoder_2 is not None:
|
513 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
514 |
+
else:
|
515 |
+
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
516 |
+
|
517 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
518 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
519 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
520 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
521 |
+
|
522 |
+
if do_classifier_free_guidance:
|
523 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
524 |
+
seq_len = negative_prompt_embeds.shape[1]
|
525 |
+
|
526 |
+
if self.text_encoder_2 is not None:
|
527 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
528 |
+
else:
|
529 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
530 |
+
|
531 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
532 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
533 |
+
|
534 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
535 |
+
bs_embed * num_images_per_prompt, -1
|
536 |
+
)
|
537 |
+
if do_classifier_free_guidance:
|
538 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
539 |
+
bs_embed * num_images_per_prompt, -1
|
540 |
+
)
|
541 |
+
|
542 |
+
if self.text_encoder is not None:
|
543 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
544 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
545 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
546 |
+
|
547 |
+
if self.text_encoder_2 is not None:
|
548 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
549 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
550 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
551 |
+
|
552 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
553 |
+
|
554 |
+
@torch.no_grad()
|
555 |
+
def get_text_prompts(self, image: Image.Image) -> str:
|
556 |
+
r"""A convenient method to extract text prompt from an image.
|
557 |
+
|
558 |
+
This is called if the user does not provide background prompt but only
|
559 |
+
the background image. We use BLIP-2 to automatically generate prompts.
|
560 |
+
|
561 |
+
Args:
|
562 |
+
image (Image.Image): A PIL image.
|
563 |
+
|
564 |
+
Returns:
|
565 |
+
A single string of text prompt.
|
566 |
+
"""
|
567 |
+
if hasattr(self, 'i2t_model'):
|
568 |
+
question = 'Question: What are in the image? Answer:'
|
569 |
+
inputs = self.i2t_processor(image, question, return_tensors='pt')
|
570 |
+
out = self.i2t_model.generate(**inputs, max_new_tokens=77)
|
571 |
+
prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
|
572 |
+
return prompt
|
573 |
+
else:
|
574 |
+
return ''
|
575 |
+
|
576 |
+
@torch.no_grad()
|
577 |
+
def encode_imgs(
|
578 |
+
self,
|
579 |
+
imgs: torch.Tensor,
|
580 |
+
generator: Optional[torch.Generator] = None,
|
581 |
+
vae: Optional[nn.Module] = None,
|
582 |
+
) -> torch.Tensor:
|
583 |
+
r"""A wrapper function for VAE encoder of the latent diffusion model.
|
584 |
+
|
585 |
+
Args:
|
586 |
+
imgs (torch.Tensor): An image to get StableDiffusion latents.
|
587 |
+
Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
|
588 |
+
generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
|
589 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
590 |
+
the demo application with TinyVAE).
|
591 |
+
|
592 |
+
Returns:
|
593 |
+
An image latent embedding with 1/8 size (depending on the auto-
|
594 |
+
encoder. Shape: (B, 4, H//8, W//8).
|
595 |
+
"""
|
596 |
+
def _retrieve_latents(
|
597 |
+
encoder_output: torch.Tensor,
|
598 |
+
generator: Optional[torch.Generator] = None,
|
599 |
+
sample_mode: str = 'sample',
|
600 |
+
):
|
601 |
+
if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
|
602 |
+
return encoder_output.latent_dist.sample(generator)
|
603 |
+
elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
|
604 |
+
return encoder_output.latent_dist.mode()
|
605 |
+
elif hasattr(encoder_output, 'latents'):
|
606 |
+
return encoder_output.latents
|
607 |
+
else:
|
608 |
+
raise AttributeError('Could not access latents of provided encoder_output')
|
609 |
+
|
610 |
+
vae = self.vae if vae is None else vae
|
611 |
+
imgs = 2 * imgs - 1
|
612 |
+
latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
|
613 |
+
return latents
|
614 |
+
|
615 |
+
@torch.no_grad()
|
616 |
+
def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
|
617 |
+
r"""A wrapper function for VAE decoder of the latent diffusion model.
|
618 |
+
|
619 |
+
Args:
|
620 |
+
latents (torch.Tensor): An image latent to get associated images.
|
621 |
+
Expected shape: (B, 4, H//8, W//8).
|
622 |
+
vae (Optional[nn.Module]): Explicitly specify VAE (used for
|
623 |
+
the demo application with TinyVAE).
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
An image latent embedding with 1/8 size (depending on the auto-
|
627 |
+
encoder. Shape: (B, 3, H, W).
|
628 |
+
"""
|
629 |
+
vae = self.vae if vae is None else vae
|
630 |
+
latents = 1 / vae.config.scaling_factor * latents
|
631 |
+
imgs = vae.decode(latents).sample
|
632 |
+
imgs = (imgs / 2 + 0.5).clip_(0, 1)
|
633 |
+
return imgs
|
634 |
+
|
635 |
+
@torch.no_grad()
|
636 |
+
def get_white_background(self, height: int, width: int) -> torch.Tensor:
|
637 |
+
r"""White background image latent for bootstrapping or in case of
|
638 |
+
absent background.
|
639 |
+
|
640 |
+
Additionally stores the maximally-sized white latent for fast retrieval
|
641 |
+
in the future. By default, we initially call this with 1024x1024 sized
|
642 |
+
white image, so the function is rarely visited twice.
|
643 |
+
|
644 |
+
Args:
|
645 |
+
height (int): The height of the white *image*, not its latent.
|
646 |
+
width (int): The width of the white *image*, not its latent.
|
647 |
+
|
648 |
+
Returns:
|
649 |
+
A white image latent of size (1, 4, height//8, width//8). A cropped
|
650 |
+
version of the stored white latent is returned if the requested
|
651 |
+
size is smaller than what we already have created.
|
652 |
+
"""
|
653 |
+
if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
|
654 |
+
white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
655 |
+
self.white = self.encode_imgs(white)
|
656 |
+
return self.white
|
657 |
+
return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]
|
658 |
+
|
659 |
+
@torch.no_grad()
|
660 |
+
def process_mask(
|
661 |
+
self,
|
662 |
+
masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
|
663 |
+
strength: Optional[Union[torch.Tensor, float]] = None,
|
664 |
+
std: Optional[Union[torch.Tensor, float]] = None,
|
665 |
+
height: int = 1024,
|
666 |
+
width: int = 1024,
|
667 |
+
use_boolean_mask: bool = True,
|
668 |
+
timesteps: Optional[torch.Tensor] = None,
|
669 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
670 |
+
) -> Tuple[torch.Tensor]:
|
671 |
+
r"""Fast preprocess of masks for region-based generation with fine-
|
672 |
+
grained controls.
|
673 |
+
|
674 |
+
Mask preprocessing is done in four steps:
|
675 |
+
1. Resizing: Resize the masks into the specified width and height by
|
676 |
+
nearest neighbor interpolation.
|
677 |
+
2. (Optional) Ordering: Masks with higher indices are considered to
|
678 |
+
cover the masks with smaller indices. Covered masks are decayed
|
679 |
+
in its alpha value by the specified factor of
|
680 |
+
`preprocess_mask_cover_alpha`.
|
681 |
+
3. Blurring: Gaussian blur is applied to the mask with the specified
|
682 |
+
standard deviation (isotropic). This results in gradual increase of
|
683 |
+
masked region as the timesteps evolve, naturally blending fore-
|
684 |
+
ground and the predesignated background. Not strictly required if
|
685 |
+
you want to produce images from scratch withoout background.
|
686 |
+
4. Quantization: Split the real-numbered masks of value between [0, 1]
|
687 |
+
into predefined noise levels for each quantized scheduling step of
|
688 |
+
the diffusion sampler. For example, if the diffusion model sampler
|
689 |
+
has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], which
|
690 |
+
is the default noise level of this module with schedule [0, 4, 12,
|
691 |
+
25, 37], the masks are split into binary masks whose values are
|
692 |
+
greater than these levels. This results in tradual increase of mask
|
693 |
+
region as the timesteps increase. Details are described in our
|
694 |
+
paper at https://arxiv.org/pdf/2403.09055.pdf.
|
695 |
+
|
696 |
+
On the Three Modes of `mask_type`:
|
697 |
+
`self.mask_type` is predefined at the initialization stage of this
|
698 |
+
pipeline. Three possible modes are available: 'discrete', 'semi-
|
699 |
+
continuous', and 'continuous'. These define the mask quantization
|
700 |
+
modes we use. Basically, this (subtly) controls the smoothness of
|
701 |
+
foreground-background blending. Continuous modes produces nonbinary
|
702 |
+
masks to further blend foreground and background latents by linear-
|
703 |
+
ly interpolating between them. Semi-continuous masks only applies
|
704 |
+
continuous mask at the last step of the LCM sampler. Due to the
|
705 |
+
large step size of the LCM scheduler, we find that our continuous
|
706 |
+
blending helps generating seamless inpainting and editing results.
|
707 |
+
|
708 |
+
Args:
|
709 |
+
masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
|
710 |
+
strength (Optional[Union[torch.Tensor, float]]): Mask strength that
|
711 |
+
overrides the default value. A globally multiplied factor to
|
712 |
+
the mask at the initial stage of processing. Can be applied
|
713 |
+
seperately for each mask.
|
714 |
+
std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
|
715 |
+
kernel's standard deviation. Overrides the default value. Can
|
716 |
+
be applied seperately for each mask.
|
717 |
+
height (int): The height of the expected generation. Mask is
|
718 |
+
resized to (height//8, width//8) with nearest neighbor inter-
|
719 |
+
polation.
|
720 |
+
width (int): The width of the expected generation. Mask is resized
|
721 |
+
to (height//8, width//8) with nearest neighbor interpolation.
|
722 |
+
use_boolean_mask (bool): Specify this to treat the mask image as
|
723 |
+
a boolean tensor. The retion with dark part darker than 0.5 of
|
724 |
+
the maximal pixel value (that is, 127.5) is considered as the
|
725 |
+
designated mask.
|
726 |
+
timesteps (Optional[torch.Tensor]): Defines the scheduler noise
|
727 |
+
levels that acts as bins of mask quantization.
|
728 |
+
preprocess_mask_cover_alpha (Optional[float]): Optional pre-
|
729 |
+
processing where each mask covered by other masks is reduced in
|
730 |
+
its alpha value by this specified factor. Overrides the default
|
731 |
+
value.
|
732 |
+
|
733 |
+
Returns: A tuple of tensors.
|
734 |
+
- masks: Preprocessed (ordered, blurred, and quantized) binary/non-
|
735 |
+
binary masks (see the explanation on `mask_type` above) for
|
736 |
+
region-based image synthesis.
|
737 |
+
- masks_blurred: Gaussian blurred masks. Used for optionally
|
738 |
+
specified foreground-background blending after image
|
739 |
+
generation.
|
740 |
+
- std: Mask blur standard deviation. Used for optionally specified
|
741 |
+
foreground-background blending after image generation.
|
742 |
+
"""
|
743 |
+
if isinstance(masks, Image.Image):
|
744 |
+
masks = [masks]
|
745 |
+
if isinstance(masks, (tuple, list)):
|
746 |
+
# Assumes white background for Image.Image;
|
747 |
+
# inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
|
748 |
+
if use_boolean_mask:
|
749 |
+
proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
|
750 |
+
else:
|
751 |
+
proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
|
752 |
+
masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
|
753 |
+
masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
|
754 |
+
masks = masks.to(self.device)
|
755 |
+
|
756 |
+
# Background mask alpha is decayed by the specified factor where foreground masks covers it.
|
757 |
+
if preprocess_mask_cover_alpha is None:
|
758 |
+
preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
|
759 |
+
if preprocess_mask_cover_alpha > 0:
|
760 |
+
masks = torch.stack([
|
761 |
+
torch.where(
|
762 |
+
masks[i + 1:].sum(dim=0) > 0,
|
763 |
+
mask * preprocess_mask_cover_alpha,
|
764 |
+
mask,
|
765 |
+
) if i < len(masks) - 1 else mask
|
766 |
+
for i, mask in enumerate(masks)
|
767 |
+
], dim=0)
|
768 |
+
|
769 |
+
# Scheduler noise levels for mask quantization.
|
770 |
+
if timesteps is None:
|
771 |
+
noise_lvs = self.noise_lvs
|
772 |
+
next_noise_lvs = self.next_noise_lvs
|
773 |
+
else:
|
774 |
+
noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
|
775 |
+
# noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
|
776 |
+
noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
|
777 |
+
next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]
|
778 |
+
|
779 |
+
# Mask preprocessing parameters are fetched from the default settings.
|
780 |
+
if std is None:
|
781 |
+
std = self.default_mask_std
|
782 |
+
if isinstance(std, (int, float)):
|
783 |
+
std = [std] * len(masks)
|
784 |
+
if isinstance(std, (list, tuple)):
|
785 |
+
std = torch.as_tensor(std, dtype=torch.float, device=self.device)
|
786 |
+
|
787 |
+
if strength is None:
|
788 |
+
strength = self.default_mask_strength
|
789 |
+
if isinstance(strength, (int, float)):
|
790 |
+
strength = [strength] * len(masks)
|
791 |
+
if isinstance(strength, (list, tuple)):
|
792 |
+
strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)
|
793 |
+
|
794 |
+
if (std > 0).any():
|
795 |
+
std = torch.where(std > 0, std, 1e-5)
|
796 |
+
masks = gaussian_lowpass(masks, std)
|
797 |
+
masks_blurred = masks
|
798 |
+
|
799 |
+
# NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
|
800 |
+
# gives unpleasant results.
|
801 |
+
masks = masks * strength[:, None, None, None]
|
802 |
+
masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)
|
803 |
+
|
804 |
+
# Mask is quantized according to the current noise levels specified by the scheduler.
|
805 |
+
if self.mask_type == 'discrete':
|
806 |
+
# Discrete mode.
|
807 |
+
masks = masks > noise_lvs
|
808 |
+
elif self.mask_type == 'semi-continuous':
|
809 |
+
# Semi-continuous mode (continuous at the last step only).
|
810 |
+
masks = torch.cat((
|
811 |
+
masks[:, :-1] > noise_lvs[:, :-1],
|
812 |
+
(
|
813 |
+
(masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
|
814 |
+
).clip_(0, 1),
|
815 |
+
), dim=1)
|
816 |
+
elif self.mask_type == 'continuous':
|
817 |
+
# Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
|
818 |
+
# decreases continuously after the discrete mode boundary to become `0` at the
|
819 |
+
# next lower threshold.
|
820 |
+
masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)
|
821 |
+
|
822 |
+
# NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
|
823 |
+
# fine-grained mask alpha channel tuning is available with this form.
|
824 |
+
# masks = masks * strength[None, :, None, None, None]
|
825 |
+
|
826 |
+
h = height // self.vae_scale_factor
|
827 |
+
w = width // self.vae_scale_factor
|
828 |
+
masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
|
829 |
+
masks = F.interpolate(masks, size=(h, w), mode='nearest')
|
830 |
+
masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
|
831 |
+
return masks, masks_blurred, std
|
832 |
+
|
833 |
+
def scheduler_scale_model_input(
|
834 |
+
self,
|
835 |
+
latent: torch.FloatTensor,
|
836 |
+
idx: int,
|
837 |
+
) -> torch.FloatTensor:
|
838 |
+
"""
|
839 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
840 |
+
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
841 |
+
|
842 |
+
Args:
|
843 |
+
sample (`torch.FloatTensor`):
|
844 |
+
The input sample.
|
845 |
+
timestep (`int`, *optional*):
|
846 |
+
The current timestep in the diffusion chain.
|
847 |
+
|
848 |
+
Returns:
|
849 |
+
`torch.FloatTensor`:
|
850 |
+
A scaled input sample.
|
851 |
+
"""
|
852 |
+
latent = latent / ((self.sigmas[idx]**2 + 1) ** 0.5)
|
853 |
+
return latent
|
854 |
+
|
855 |
+
def scheduler_step(
|
856 |
+
self,
|
857 |
+
noise_pred: torch.Tensor,
|
858 |
+
idx: int,
|
859 |
+
latent: torch.Tensor,
|
860 |
+
) -> torch.Tensor:
|
861 |
+
r"""Denoise-only step for reverse diffusion scheduler.
|
862 |
+
|
863 |
+
Designed to match the interface of the original `pipe.scheduler.step`,
|
864 |
+
which is a combination of this method and the following
|
865 |
+
`scheduler_add_noise`.
|
866 |
+
|
867 |
+
Args:
|
868 |
+
noise_pred (torch.Tensor): Noise prediction results from the U-Net.
|
869 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
870 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
871 |
+
latent (torch.Tensor): Noisy latent.
|
872 |
+
|
873 |
+
Returns:
|
874 |
+
A denoised tensor with the same size as latent.
|
875 |
+
"""
|
876 |
+
# Upcast to avoid precision issues when computing prev_sample.
|
877 |
+
latent = latent.to(torch.float32)
|
878 |
+
|
879 |
+
# 1. Compute predicted original sample (x_0) from sigma-scaled predicted noise.
|
880 |
+
assert self.scheduler.config.prediction_type == 'epsilon', 'Only supports `prediction_type` of `epsilon` for now.'
|
881 |
+
# pred_original_sample = latent - self.sigma_hats[idx] * noise_pred
|
882 |
+
# prev_sample = pred_original_sample + noise_pred * (self.dt[i] + self.sigma_hats[i])
|
883 |
+
# return pred_original_sample.to(self.dtype)
|
884 |
+
|
885 |
+
# 2. Convert to an ODE derivative.
|
886 |
+
prev_sample = latent + noise_pred * self.dt[idx]
|
887 |
+
return prev_sample.to(self.dtype)
|
888 |
+
|
889 |
+
def scheduler_add_noise(
|
890 |
+
self,
|
891 |
+
latent: torch.Tensor,
|
892 |
+
noise: Optional[torch.Tensor],
|
893 |
+
idx: int,
|
894 |
+
s_noise: float = 1.0,
|
895 |
+
initial: bool = False,
|
896 |
+
) -> torch.Tensor:
|
897 |
+
r"""Separated noise-add step for the reverse diffusion scheduler.
|
898 |
+
|
899 |
+
Designed to match the interface of the original
|
900 |
+
`pipe.scheduler.add_noise`.
|
901 |
+
|
902 |
+
Args:
|
903 |
+
latent (torch.Tensor): Denoised latent.
|
904 |
+
noise (torch.Tensor): Added noise. Can be None. If None, a random
|
905 |
+
noise is newly sampled for addition.
|
906 |
+
idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
|
907 |
+
for the timesteps tensor (ranged in [0, len(timesteps)-1]).
|
908 |
+
|
909 |
+
Returns:
|
910 |
+
A noisy tensor with the same size as latent.
|
911 |
+
"""
|
912 |
+
if initial:
|
913 |
+
if idx < len(self.sigmas) and idx >= 0:
|
914 |
+
noise = torch.randn_like(latent) if noise is None else noise
|
915 |
+
return latent + self.sigmas[idx] * noise
|
916 |
+
else:
|
917 |
+
return latent
|
918 |
+
else:
|
919 |
+
# 3. Post-add noise.
|
920 |
+
noise_lv = (self.sigma_hats[idx]**2 - self.sigmas[idx]**2) ** 0.5
|
921 |
+
if self.gammas[idx] > 0 and noise_lv > 0 and s_noise > 0 and idx < len(self.sigmas) and idx >= 0:
|
922 |
+
noise = torch.randn_like(latent) if noise is None else noise
|
923 |
+
eps = noise * s_noise * noise_lv
|
924 |
+
latent = latent + eps
|
925 |
+
# pred_original_sample = pred_original_sample + eps
|
926 |
+
return latent
|
927 |
+
|
928 |
+
@torch.no_grad()
|
929 |
+
def __call__(
|
930 |
+
self,
|
931 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
932 |
+
negative_prompts: Union[str, List[str]] = '',
|
933 |
+
suffix: Optional[str] = None, #', background is ',
|
934 |
+
background: Optional[Union[torch.Tensor, Image.Image]] = None,
|
935 |
+
background_prompt: Optional[str] = None,
|
936 |
+
background_negative_prompt: str = '',
|
937 |
+
height: int = 1024,
|
938 |
+
width: int = 1024,
|
939 |
+
num_inference_steps: Optional[int] = None,
|
940 |
+
guidance_scale: Optional[float] = None,
|
941 |
+
prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
942 |
+
masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
943 |
+
mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
944 |
+
mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
|
945 |
+
use_boolean_mask: bool = True,
|
946 |
+
do_blend: bool = True,
|
947 |
+
tile_size: int = 1024,
|
948 |
+
bootstrap_steps: Optional[int] = None,
|
949 |
+
boostrap_mix_steps: Optional[float] = None,
|
950 |
+
bootstrap_leak_sensitivity: Optional[float] = None,
|
951 |
+
preprocess_mask_cover_alpha: Optional[float] = None,
|
952 |
+
) -> Image.Image:
|
953 |
+
r"""Arbitrary-size image generation from multiple pairs of (regional)
|
954 |
+
text prompt-mask pairs.
|
955 |
+
|
956 |
+
This is a main routine for this pipeline.
|
957 |
+
|
958 |
+
Example:
|
959 |
+
>>> device = torch.device('cuda:0')
|
960 |
+
>>> smd = StableMultiDiffusionPipeline(device)
|
961 |
+
>>> prompts = {... specify prompts}
|
962 |
+
>>> masks = {... specify mask tensors}
|
963 |
+
>>> height, width = masks.shape[-2:]
|
964 |
+
>>> image = smd(
|
965 |
+
>>> prompts, masks=masks.float(), height=height, width=width)
|
966 |
+
>>> image.save('my_beautiful_creation.png')
|
967 |
+
|
968 |
+
Args:
|
969 |
+
prompts (Union[str, List[str]]): A text prompt.
|
970 |
+
negative_prompts (Union[str, List[str]]): A negative text prompt.
|
971 |
+
suffix (Optional[str]): One option for blending foreground prompts
|
972 |
+
with background prompts by simply appending background prompt
|
973 |
+
to the end of each foreground prompt with this `middle word` in
|
974 |
+
between. For example, if you set this as `, background is`,
|
975 |
+
then the foreground prompt will be changed into
|
976 |
+
`(fg), background is (bg)` before conditional generation.
|
977 |
+
background (Optional[Union[torch.Tensor, Image.Image]]): a
|
978 |
+
background image, if the user wants to draw in front of the
|
979 |
+
specified image. Background prompt will automatically generated
|
980 |
+
with a BLIP-2 model.
|
981 |
+
background_prompt (Optional[str]): The background prompt is used
|
982 |
+
for preprocessing foreground prompt embeddings to blend
|
983 |
+
foreground and background.
|
984 |
+
background_negative_prompt (Optional[str]): The negative background
|
985 |
+
prompt.
|
986 |
+
height (int): Height of a generated image. It is tiled if larger
|
987 |
+
than `tile_size`.
|
988 |
+
width (int): Width of a generated image. It is tiled if larger
|
989 |
+
than `tile_size`.
|
990 |
+
num_inference_steps (Optional[int]): Number of inference steps.
|
991 |
+
Default inference scheduling is used if none is specified.
|
992 |
+
guidance_scale (Optional[float]): Classifier guidance scale.
|
993 |
+
Default value is used if none is specified.
|
994 |
+
prompt_strength (float): Overrides default value. Preprocess
|
995 |
+
foreground prompts globally by linearly interpolating its
|
996 |
+
embedding with the background prompt embeddint with specified
|
997 |
+
mix ratio. Useful control handle for foreground blending.
|
998 |
+
Recommended range: 0.5-1.
|
999 |
+
masks (Optional[Union[Image.Image, List[Image.Image]]]): a list of
|
1000 |
+
mask images. Each mask associates with each of the text prompts
|
1001 |
+
and each of the negative prompts. If specified as an image, it
|
1002 |
+
regards the image as a boolean mask. Also accepts torch.Tensor
|
1003 |
+
masks, which can have nonbinary values for fine-grained
|
1004 |
+
controls in mixing regional generations.
|
1005 |
+
mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
|
1006 |
+
Overrides the default value. an be assigned for each mask
|
1007 |
+
separately. Preprocess mask by multiplying it globally with the
|
1008 |
+
specified variable. Caution: extremely sensitive. Recommended
|
1009 |
+
range: 0.98-1.
|
1010 |
+
mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
|
1011 |
+
Overrides the default value. Can be assigned for each mask
|
1012 |
+
separately. Preprocess mask with Gaussian blur with specified
|
1013 |
+
standard deviation. Recommended range: 0-64.
|
1014 |
+
use_boolean_mask (bool): Turn this off if you want to treat the
|
1015 |
+
mask image as nonbinary one. The module will use the last
|
1016 |
+
channel of the given image in `masks` as the mask value.
|
1017 |
+
do_blend (bool): Blend the generated foreground and the optionally
|
1018 |
+
predefined background by smooth boundary obtained from Gaussian
|
1019 |
+
blurs of the foreground `masks` with the given `mask_stds`.
|
1020 |
+
tile_size (Optional[int]): Tile size of the panorama generation.
|
1021 |
+
Works best with the default training size of the Stable-
|
1022 |
+
Diffusion model, i.e., 1024 or 1024 for SD1.5 and 1024 for SDXL.
|
1023 |
+
bootstrap_steps (int): Overrides the default value. Bootstrapping
|
1024 |
+
stage steps to encourage region separation. Recommended range:
|
1025 |
+
1-3.
|
1026 |
+
boostrap_mix_steps (float): Overrides the default value.
|
1027 |
+
Bootstrapping background is a linear interpolation between
|
1028 |
+
background latent and the white image latent. This handle
|
1029 |
+
controls the mix ratio. Available range: 0-(number of
|
1030 |
+
bootstrapping inference steps). For example, 2.3 means that for
|
1031 |
+
the first two steps, white image is used as a bootstrapping
|
1032 |
+
background and in the third step, mixture of white (0.3) and
|
1033 |
+
registered background (0.7) is used as a bootstrapping
|
1034 |
+
background.
|
1035 |
+
bootstrap_leak_sensitivity (float): Overrides the default value.
|
1036 |
+
Postprocessing at each inference step by masking away the
|
1037 |
+
remaining bootstrap backgrounds t Recommended range: 0-1.
|
1038 |
+
preprocess_mask_cover_alpha (float): Overrides the default value.
|
1039 |
+
Optional preprocessing where each mask covered by other masks
|
1040 |
+
is reduced in its alpha value by this specified factor.
|
1041 |
+
|
1042 |
+
Returns: A PIL.Image image of a panorama (large-size) image.
|
1043 |
+
"""
|
1044 |
+
|
1045 |
+
### Simplest cases
|
1046 |
+
|
1047 |
+
# prompts is None: return background.
|
1048 |
+
# masks is None but prompts is not None: return prompts
|
1049 |
+
# masks is not None and prompts is not None: Do StableMultiDiffusion.
|
1050 |
+
|
1051 |
+
if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
|
1052 |
+
if background is None and background_prompt is not None:
|
1053 |
+
return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
|
1054 |
+
return background
|
1055 |
+
elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
|
1056 |
+
return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)
|
1057 |
+
|
1058 |
+
|
1059 |
+
### Prepare generation
|
1060 |
+
|
1061 |
+
if num_inference_steps is not None:
|
1062 |
+
# self.prepare_lcm_schedule(list(range(num_inference_steps)), num_inference_steps)
|
1063 |
+
self.prepare_lightning_schedule(list(range(num_inference_steps)), num_inference_steps)
|
1064 |
+
|
1065 |
+
if guidance_scale is None:
|
1066 |
+
guidance_scale = self.default_guidance_scale
|
1067 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1068 |
+
|
1069 |
+
|
1070 |
+
### Prompts & Masks
|
1071 |
+
|
1072 |
+
# asserts #m > 0 and #p > 0.
|
1073 |
+
# #m == #p == #n > 0: We happily generate according to the prompts & masks.
|
1074 |
+
# #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
|
1075 |
+
# #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
|
1076 |
+
|
1077 |
+
if isinstance(masks, Image.Image):
|
1078 |
+
masks = [masks]
|
1079 |
+
if isinstance(prompts, str):
|
1080 |
+
prompts = [prompts]
|
1081 |
+
if isinstance(negative_prompts, str):
|
1082 |
+
negative_prompts = [negative_prompts]
|
1083 |
+
num_masks = len(masks)
|
1084 |
+
num_prompts = len(prompts)
|
1085 |
+
num_nprompts = len(negative_prompts)
|
1086 |
+
assert num_prompts in (num_masks, 1), \
|
1087 |
+
f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
|
1088 |
+
assert num_nprompts in (num_prompts, 1), \
|
1089 |
+
f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'
|
1090 |
+
|
1091 |
+
fg_masks, masks_g, std = self.process_mask(
|
1092 |
+
masks,
|
1093 |
+
mask_strengths,
|
1094 |
+
mask_stds,
|
1095 |
+
height=height,
|
1096 |
+
width=width,
|
1097 |
+
use_boolean_mask=use_boolean_mask,
|
1098 |
+
timesteps=self.timesteps,
|
1099 |
+
preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
|
1100 |
+
) # (p, t, 1, H, W)
|
1101 |
+
bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1) # (T, 1, h, w)
|
1102 |
+
has_background = bg_masks.sum() > 0
|
1103 |
+
|
1104 |
+
h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
|
1105 |
+
w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor
|
1106 |
+
|
1107 |
+
|
1108 |
+
### Background
|
1109 |
+
|
1110 |
+
# background == None && background_prompt == None: Initialize with white background.
|
1111 |
+
# background == None && background_prompt != None: Generate background *along with other prompts*.
|
1112 |
+
# background != None && background_prompt == None: Retrieve text prompt using BLIP.
|
1113 |
+
# background != None && background_prompt != None: Use the given arguments.
|
1114 |
+
|
1115 |
+
# not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
|
1116 |
+
# has_background && prompt_strength != 1: mix only for this case.
|
1117 |
+
|
1118 |
+
bg_latent = None
|
1119 |
+
if has_background:
|
1120 |
+
if background is None and background_prompt is not None:
|
1121 |
+
fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
|
1122 |
+
if suffix is not None:
|
1123 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
1124 |
+
prompts = [background_prompt] + prompts
|
1125 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
1126 |
+
has_background = False # Regard that background does not exist.
|
1127 |
+
else:
|
1128 |
+
if background is None and background_prompt is None:
|
1129 |
+
background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
|
1130 |
+
background_prompt = 'simple white background image'
|
1131 |
+
elif background is not None and background_prompt is None:
|
1132 |
+
background_prompt = self.get_text_prompts(background)
|
1133 |
+
if suffix is not None:
|
1134 |
+
prompts = [p + suffix + background_prompt for p in prompts]
|
1135 |
+
prompts = [background_prompt] + prompts
|
1136 |
+
negative_prompts = [background_negative_prompt] + negative_prompts
|
1137 |
+
if isinstance(background, Image.Image):
|
1138 |
+
background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
|
1139 |
+
background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
|
1140 |
+
bg_latent = self.encode_imgs(background)
|
1141 |
+
|
1142 |
+
# Bootstrapping stage preparation.
|
1143 |
+
|
1144 |
+
if bootstrap_steps is None:
|
1145 |
+
bootstrap_steps = self.default_bootstrap_steps
|
1146 |
+
if boostrap_mix_steps is None:
|
1147 |
+
boostrap_mix_steps = self.default_boostrap_mix_steps
|
1148 |
+
if bootstrap_leak_sensitivity is None:
|
1149 |
+
bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
|
1150 |
+
if bootstrap_steps > 0:
|
1151 |
+
height_ = min(height, tile_size)
|
1152 |
+
width_ = min(width, tile_size)
|
1153 |
+
white = self.get_white_background(height, width) # (1, 4, h, w)
|
1154 |
+
|
1155 |
+
|
1156 |
+
### Prepare text embeddings (optimized for the minimal encoder batch size)
|
1157 |
+
|
1158 |
+
# SDXL pipeline settings.
|
1159 |
+
batch_size = 1
|
1160 |
+
output_type = 'pil'
|
1161 |
+
|
1162 |
+
guidance_rescale = 0.7
|
1163 |
+
|
1164 |
+
prompt_2 = None
|
1165 |
+
device = self.device
|
1166 |
+
num_images_per_prompt = 1
|
1167 |
+
negative_prompt_2 = None
|
1168 |
+
|
1169 |
+
original_size = (height, width)
|
1170 |
+
target_size = (height, width)
|
1171 |
+
crops_coords_top_left = (0, 0)
|
1172 |
+
negative_crops_coords_top_left = (0, 0)
|
1173 |
+
negative_original_size = None
|
1174 |
+
negative_target_size = None
|
1175 |
+
pooled_prompt_embeds = None
|
1176 |
+
negative_pooled_prompt_embeds = None
|
1177 |
+
text_encoder_lora_scale = None
|
1178 |
+
|
1179 |
+
prompt_embeds = None
|
1180 |
+
negative_prompt_embeds = None
|
1181 |
+
|
1182 |
+
(
|
1183 |
+
prompt_embeds,
|
1184 |
+
negative_prompt_embeds,
|
1185 |
+
pooled_prompt_embeds,
|
1186 |
+
negative_pooled_prompt_embeds,
|
1187 |
+
) = self.encode_prompt(
|
1188 |
+
prompt=prompts,
|
1189 |
+
prompt_2=prompt_2,
|
1190 |
+
device=device,
|
1191 |
+
num_images_per_prompt=num_images_per_prompt,
|
1192 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
1193 |
+
negative_prompt=negative_prompts,
|
1194 |
+
negative_prompt_2=negative_prompt_2,
|
1195 |
+
prompt_embeds=prompt_embeds,
|
1196 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1197 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1198 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1199 |
+
lora_scale=text_encoder_lora_scale,
|
1200 |
+
)
|
1201 |
+
|
1202 |
+
add_text_embeds = pooled_prompt_embeds
|
1203 |
+
if self.text_encoder_2 is None:
|
1204 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1205 |
+
else:
|
1206 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1207 |
+
|
1208 |
+
add_time_ids = self._get_add_time_ids(
|
1209 |
+
original_size,
|
1210 |
+
crops_coords_top_left,
|
1211 |
+
target_size,
|
1212 |
+
dtype=prompt_embeds.dtype,
|
1213 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1214 |
+
)
|
1215 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1216 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1217 |
+
negative_original_size,
|
1218 |
+
negative_crops_coords_top_left,
|
1219 |
+
negative_target_size,
|
1220 |
+
dtype=prompt_embeds.dtype,
|
1221 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1222 |
+
)
|
1223 |
+
else:
|
1224 |
+
negative_add_time_ids = add_time_ids
|
1225 |
+
|
1226 |
+
if has_background:
|
1227 |
+
# First channel is background prompt text embeds. Background prompt itself is not used for generation.
|
1228 |
+
s = prompt_strengths
|
1229 |
+
if prompt_strengths is None:
|
1230 |
+
s = self.default_prompt_strength
|
1231 |
+
if isinstance(s, (int, float)):
|
1232 |
+
s = [s] * num_prompts
|
1233 |
+
if isinstance(s, (list, tuple)):
|
1234 |
+
assert len(s) == num_prompts, \
|
1235 |
+
f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
|
1236 |
+
s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
|
1237 |
+
s = s[:, None, None]
|
1238 |
+
|
1239 |
+
be = prompt_embeds[:1]
|
1240 |
+
fe = prompt_embeds[1:]
|
1241 |
+
prompt_embeds = torch.lerp(be, fe, s) # (p, 77, 1024)
|
1242 |
+
|
1243 |
+
if negative_prompt_embeds is not None:
|
1244 |
+
bu = negative_prompt_embeds[:1]
|
1245 |
+
fu = negative_prompt_embeds[1:]
|
1246 |
+
if num_prompts > num_nprompts:
|
1247 |
+
# # negative prompts = 1; # prompts > 1.
|
1248 |
+
assert fu.shape[0] == 1 and fe.shape == num_prompts
|
1249 |
+
fu = fu.repeat(num_prompts, 1, 1)
|
1250 |
+
negative_prompt_embeds = torch.lerp(bu, fu, s) # (n, 77, 1024)
|
1251 |
+
elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
|
1252 |
+
# # negative prompts = 1; # prompts > 1.
|
1253 |
+
assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
|
1254 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
|
1255 |
+
# assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
|
1256 |
+
if num_masks > num_prompts:
|
1257 |
+
assert masks.shape[0] == num_masks and num_prompts == 1
|
1258 |
+
prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
|
1259 |
+
if negative_prompt_embeds is not None:
|
1260 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)
|
1261 |
+
|
1262 |
+
# SDXL pipeline settings.
|
1263 |
+
if do_classifier_free_guidance:
|
1264 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1265 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1266 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1267 |
+
del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
|
1268 |
+
|
1269 |
+
prompt_embeds = prompt_embeds.to(device)
|
1270 |
+
add_text_embeds = add_text_embeds.to(device)
|
1271 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1272 |
+
|
1273 |
+
|
1274 |
+
### Run
|
1275 |
+
|
1276 |
+
# Latent initialization.
|
1277 |
+
if self.timesteps[0] < 999 and has_background:
|
1278 |
+
latents = self.scheduler_add_noise(bg_latents, None, 0, initial=True)
|
1279 |
+
else:
|
1280 |
+
latents = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
|
1281 |
+
latents = latents * self.scheduler.init_noise_sigma
|
1282 |
+
|
1283 |
+
# Tiling (if needed).
|
1284 |
+
if height > tile_size or width > tile_size:
|
1285 |
+
t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
|
1286 |
+
views, tile_masks = get_panorama_views(h, w, t)
|
1287 |
+
tile_masks = tile_masks.to(self.device)
|
1288 |
+
else:
|
1289 |
+
views = [(0, h, 0, w)]
|
1290 |
+
tile_masks = latents.new_ones((1, 1, h, w))
|
1291 |
+
value = torch.zeros_like(latents)
|
1292 |
+
count_all = torch.zeros_like(latents)
|
1293 |
+
|
1294 |
+
with torch.autocast('cuda'):
|
1295 |
+
for i, t in enumerate(tqdm(self.timesteps)):
|
1296 |
+
fg_mask = fg_masks[:, i]
|
1297 |
+
bg_mask = bg_masks[i:i + 1]
|
1298 |
+
|
1299 |
+
value.zero_()
|
1300 |
+
count_all.zero_()
|
1301 |
+
for j, (h_start, h_end, w_start, w_end) in enumerate(views):
|
1302 |
+
fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
|
1303 |
+
latents_ = latents[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)
|
1304 |
+
|
1305 |
+
# Additional arguments for the SDXL pipeline.
|
1306 |
+
add_time_ids_input = add_time_ids.clone()
|
1307 |
+
add_time_ids_input[:, 2] = h_start * self.vae_scale_factor
|
1308 |
+
add_time_ids_input[:, 3] = w_start * self.vae_scale_factor
|
1309 |
+
add_time_ids_input = add_time_ids_input.repeat_interleave(num_prompts, dim=0)
|
1310 |
+
|
1311 |
+
# Bootstrap for tight background.
|
1312 |
+
if i < bootstrap_steps:
|
1313 |
+
mix_ratio = min(1, max(0, boostrap_mix_steps - i))
|
1314 |
+
# Treat the first foreground latent as the background latent if one does not exist.
|
1315 |
+
bg_latents_ = bg_latents[..., h_start:h_end, w_start:w_end] if has_background else latents_[:1]
|
1316 |
+
white_ = white[..., h_start:h_end, w_start:w_end]
|
1317 |
+
white_ = self.scheduler_add_noise(white_, None, i, initial=True)
|
1318 |
+
bg_latents_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latents_
|
1319 |
+
latents_ = (1.0 - fg_mask_) * bg_latents_ + fg_mask_ * latents_
|
1320 |
+
|
1321 |
+
# Centering.
|
1322 |
+
latents_ = shift_to_mask_bbox_center(latents_, fg_mask_, reverse=True)
|
1323 |
+
|
1324 |
+
latent_model_input = torch.cat([latents_] * 2) if do_classifier_free_guidance else latents_
|
1325 |
+
latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)
|
1326 |
+
|
1327 |
+
# Perform one step of the reverse diffusion.
|
1328 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_input}
|
1329 |
+
noise_pred = self.unet(
|
1330 |
+
latent_model_input,
|
1331 |
+
t,
|
1332 |
+
encoder_hidden_states=prompt_embeds,
|
1333 |
+
timestep_cond=None,
|
1334 |
+
cross_attention_kwargs=None,
|
1335 |
+
added_cond_kwargs=added_cond_kwargs,
|
1336 |
+
return_dict=False,
|
1337 |
+
)[0]
|
1338 |
+
|
1339 |
+
if do_classifier_free_guidance:
|
1340 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
1341 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
1342 |
+
|
1343 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
1344 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1345 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)
|
1346 |
+
|
1347 |
+
latents_ = self.scheduler_step(noise_pred, i, latents_)
|
1348 |
+
|
1349 |
+
if i < bootstrap_steps:
|
1350 |
+
# Uncentering.
|
1351 |
+
latents_ = shift_to_mask_bbox_center(latents_, fg_mask_)
|
1352 |
+
|
1353 |
+
# Remove leakage (optional).
|
1354 |
+
leak = (latents_ - bg_latents_).pow(2).mean(dim=1, keepdim=True)
|
1355 |
+
leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
|
1356 |
+
fg_mask_ = fg_mask_ * leak_sigmoid
|
1357 |
+
|
1358 |
+
# Mix the latents.
|
1359 |
+
fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
|
1360 |
+
value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latents_).sum(dim=0, keepdim=True)
|
1361 |
+
count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)
|
1362 |
+
|
1363 |
+
latents = torch.where(count_all > 0, value / count_all, value)
|
1364 |
+
bg_mask = (1 - count_all).clip_(0, 1) # (T, 1, h, w)
|
1365 |
+
if has_background:
|
1366 |
+
latents = (1 - bg_mask) * latents + bg_mask * bg_latents
|
1367 |
+
|
1368 |
+
# Noise is added after mixing.
|
1369 |
+
if i < len(self.timesteps) - 1:
|
1370 |
+
latents = self.scheduler_add_noise(latents, None, i + 1)
|
1371 |
+
|
1372 |
+
if not output_type == "latent":
|
1373 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1374 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1375 |
+
|
1376 |
+
if needs_upcasting:
|
1377 |
+
self.upcast_vae()
|
1378 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1379 |
+
|
1380 |
+
# unscale/denormalize the latents
|
1381 |
+
# denormalize with the mean and std if available and not None
|
1382 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
1383 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
1384 |
+
if has_latents_mean and has_latents_std:
|
1385 |
+
latents_mean = (
|
1386 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1387 |
+
)
|
1388 |
+
latents_std = (
|
1389 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1390 |
+
)
|
1391 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
1392 |
+
else:
|
1393 |
+
latents = latents / self.vae.config.scaling_factor
|
1394 |
+
|
1395 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1396 |
+
|
1397 |
+
# cast back to fp16 if needed
|
1398 |
+
if needs_upcasting:
|
1399 |
+
self.vae.to(dtype=torch.float16)
|
1400 |
+
else:
|
1401 |
+
image = latents
|
1402 |
+
|
1403 |
+
# Return PIL Image.
|
1404 |
+
image = image[0].clip_(-1, 1) * 0.5 + 0.5
|
1405 |
+
if has_background and do_blend:
|
1406 |
+
fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
|
1407 |
+
image = blend(image, background[0], fg_mask)
|
1408 |
+
else:
|
1409 |
+
image = T.ToPILImage()(image)
|
1410 |
+
return image
|
prompt_util.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Tuple, Union
|
2 |
+
|
3 |
+
|
4 |
+
quality_prompt_list = [
|
5 |
+
{
|
6 |
+
"name": "(None)",
|
7 |
+
"prompt": "{prompt}",
|
8 |
+
"negative_prompt": "nsfw, lowres",
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"name": "Standard v3.0",
|
12 |
+
"prompt": "{prompt}, masterpiece, best quality",
|
13 |
+
"negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "Standard v3.1",
|
17 |
+
"prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
|
18 |
+
"negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"name": "Light v3.1",
|
22 |
+
"prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
|
23 |
+
"negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "Heavy v3.1",
|
27 |
+
"prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
|
28 |
+
"negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
|
29 |
+
},
|
30 |
+
]
|
31 |
+
|
32 |
+
style_list = [
|
33 |
+
{
|
34 |
+
"name": "(None)",
|
35 |
+
"prompt": "{prompt}",
|
36 |
+
"negative_prompt": "",
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"name": "Cinematic",
|
40 |
+
"prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
41 |
+
"negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"name": "Photographic",
|
45 |
+
"prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
46 |
+
"negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"name": "Anime",
|
50 |
+
"prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
|
51 |
+
"negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "Manga",
|
55 |
+
"prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
|
56 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"name": "Digital Art",
|
60 |
+
"prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
|
61 |
+
"negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"name": "Pixel art",
|
65 |
+
"prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
|
66 |
+
"negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"name": "Fantasy art",
|
70 |
+
"prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
71 |
+
"negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"name": "Neonpunk",
|
75 |
+
"prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
76 |
+
"negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"name": "3D Model",
|
80 |
+
"prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
|
81 |
+
"negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
|
82 |
+
},
|
83 |
+
]
|
84 |
+
|
85 |
+
|
86 |
+
_style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
|
87 |
+
_quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
|
88 |
+
|
89 |
+
|
90 |
+
def preprocess_prompt(
|
91 |
+
positive: str,
|
92 |
+
negative: str = "",
|
93 |
+
style_dict: Dict[str, dict] = _quality_dict,
|
94 |
+
style_name: str = "Standard v3.1", # "Heavy v3.1"
|
95 |
+
add_style: bool = True,
|
96 |
+
) -> Tuple[str, str]:
|
97 |
+
p, n = style_dict.get(style_name, style_dict["(None)"])
|
98 |
+
|
99 |
+
if add_style and positive.strip():
|
100 |
+
formatted_positive = p.format(prompt=positive)
|
101 |
+
else:
|
102 |
+
formatted_positive = positive
|
103 |
+
|
104 |
+
combined_negative = n
|
105 |
+
if negative.strip():
|
106 |
+
if combined_negative:
|
107 |
+
combined_negative += ", " + negative
|
108 |
+
else:
|
109 |
+
combined_negative = negative
|
110 |
+
|
111 |
+
return formatted_positive, combined_negative
|
112 |
+
|
113 |
+
|
114 |
+
def preprocess_prompts(
|
115 |
+
positives: List[str],
|
116 |
+
negatives: List[str] = None,
|
117 |
+
style_dict = _style_dict,
|
118 |
+
style_name: str = "Manga", # "(None)"
|
119 |
+
quality_dict = _quality_dict,
|
120 |
+
quality_name: str = "Standard v3.1", # "Heavy v3.1"
|
121 |
+
add_style: bool = True,
|
122 |
+
add_quality_tags = True,
|
123 |
+
) -> Tuple[List[str], List[str]]:
|
124 |
+
if negatives is None:
|
125 |
+
negatives = ['' for _ in positives]
|
126 |
+
|
127 |
+
positives_ = []
|
128 |
+
negatives_ = []
|
129 |
+
for pos, neg in zip(positives, negatives):
|
130 |
+
pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
|
131 |
+
pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
|
132 |
+
positives_.append(pos)
|
133 |
+
negatives_.append(neg)
|
134 |
+
return positives_, negatives_
|
135 |
+
|
136 |
+
|
137 |
+
def print_prompts(
|
138 |
+
positives: Union[str, List[str]],
|
139 |
+
negatives: Union[str, List[str]],
|
140 |
+
has_background: bool = False,
|
141 |
+
) -> None:
|
142 |
+
if isinstance(positives, str):
|
143 |
+
positives = [positives]
|
144 |
+
if isinstance(negatives, str):
|
145 |
+
negatives = [negatives]
|
146 |
+
|
147 |
+
for i, prompt in enumerate(positives):
|
148 |
+
prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
|
149 |
+
if has_background else f'Prompt{i + 1}')
|
150 |
+
print(prefix + ': ' + prompt)
|
151 |
+
for i, prompt in enumerate(negatives):
|
152 |
+
prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
|
153 |
+
if has_background else f'Negative Prompt{i + 1}')
|
154 |
+
print(prefix + ': ' + prompt)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision
|
3 |
+
xformers==0.0.22
|
4 |
+
einops
|
5 |
+
diffusers
|
6 |
+
transformers
|
7 |
+
huggingface_hub[torch]
|
8 |
+
gradio
|
9 |
+
Pillow
|
10 |
+
emoji
|
11 |
+
numpy
|
12 |
+
tqdm
|
13 |
+
jupyterlab
|
14 |
+
spaces
|
share_btn.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
2 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
3 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
4 |
+
</svg>"""
|
5 |
+
|
6 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
|
7 |
+
style="color: #ffffff;
|
8 |
+
"
|
9 |
+
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
10 |
+
|
11 |
+
share_js = """async () => {
|
12 |
+
async function uploadFile(file){
|
13 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
14 |
+
const response = await fetch(UPLOAD_URL, {
|
15 |
+
method: 'POST',
|
16 |
+
headers: {
|
17 |
+
'Content-Type': file.type,
|
18 |
+
'X-Requested-With': 'XMLHttpRequest',
|
19 |
+
},
|
20 |
+
body: file, /// <- File inherits from Blob
|
21 |
+
});
|
22 |
+
const url = await response.text();
|
23 |
+
return url;
|
24 |
+
}
|
25 |
+
const gradioEl = document.querySelector('body > gradio-app');
|
26 |
+
const imgEls = gradioEl.querySelectorAll('#output-screen img');
|
27 |
+
const shareBtnEl = gradioEl.querySelector('#share-btn');
|
28 |
+
const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
|
29 |
+
const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
|
30 |
+
if(!imgEls.length){
|
31 |
+
return;
|
32 |
+
};
|
33 |
+
shareBtnEl.style.pointerEvents = 'none';
|
34 |
+
shareIconEl.style.display = 'none';
|
35 |
+
loadingIconEl.style.removeProperty('display');
|
36 |
+
const files = await Promise.all(
|
37 |
+
[...imgEls].map(async (imgEl) => {
|
38 |
+
const res = await fetch(imgEl.src);
|
39 |
+
const blob = await res.blob();
|
40 |
+
const imgId = Date.now() % 200;
|
41 |
+
const fileName = `diffuse-the-rest-${{imgId}}.jpg`;
|
42 |
+
return new File([blob], fileName, { type: 'image/jpeg' });
|
43 |
+
})
|
44 |
+
);
|
45 |
+
const urls = await Promise.all(files.map((f) => uploadFile(f)));
|
46 |
+
const htmlImgs = urls.map(url => `<img src='${url}' width='2560' height='1024'>`);
|
47 |
+
const descriptionMd = `<div style='display: flex; flex-wrap: wrap; column-gap: 0.75rem;'>
|
48 |
+
${htmlImgs.join(`\n`)}
|
49 |
+
</div>`;
|
50 |
+
const params = new URLSearchParams({
|
51 |
+
title: <p>My creation</p>,
|
52 |
+
description: descriptionMd,
|
53 |
+
});
|
54 |
+
const paramsStr = params.toString();
|
55 |
+
window.open(`https://huggingface.co/spaces/ironjr/SemanticPaletteXL/discussions/new?${paramsStr}`, '_blank');
|
56 |
+
shareBtnEl.style.removeProperty('pointer-events');
|
57 |
+
shareIconEl.style.removeProperty('display');
|
58 |
+
loadingIconEl.style.display = 'none';
|
59 |
+
}"""
|
util.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Jaerin Lee
|
2 |
+
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 |
+
# of this software and associated documentation files (the "Software"), to deal
|
5 |
+
# in the Software without restriction, including without limitation the rights
|
6 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
7 |
+
# copies of the Software, and to permit persons to whom the Software is
|
8 |
+
# furnished to do so, subject to the following conditions:
|
9 |
+
|
10 |
+
# The above copyright notice and this permission notice shall be included in all
|
11 |
+
# copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
# SOFTWARE.
|
20 |
+
|
21 |
+
import concurrent.futures
|
22 |
+
import time
|
23 |
+
from typing import Any, Callable, List, Literal, Tuple, Union
|
24 |
+
|
25 |
+
from PIL import Image
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
import torch.cuda.amp as amp
|
31 |
+
import torchvision.transforms as T
|
32 |
+
import torchvision.transforms.functional as TF
|
33 |
+
|
34 |
+
from diffusers import (
|
35 |
+
DiffusionPipeline,
|
36 |
+
StableDiffusionPipeline,
|
37 |
+
StableDiffusionXLPipeline,
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def seed_everything(seed: int) -> None:
|
42 |
+
torch.manual_seed(seed)
|
43 |
+
torch.cuda.manual_seed(seed)
|
44 |
+
torch.backends.cudnn.deterministic = True
|
45 |
+
torch.backends.cudnn.benchmark = True
|
46 |
+
|
47 |
+
|
48 |
+
def load_model(
|
49 |
+
model_key: str,
|
50 |
+
sd_version: Literal['1.5', 'xl'],
|
51 |
+
device: torch.device,
|
52 |
+
dtype: torch.dtype,
|
53 |
+
) -> torch.nn.Module:
|
54 |
+
if model_key.endswith('.safetensors'):
|
55 |
+
if sd_version == '1.5':
|
56 |
+
pipeline = StableDiffusionPipeline
|
57 |
+
elif sd_version == 'xl':
|
58 |
+
pipeline = StableDiffusionXLPipeline
|
59 |
+
else:
|
60 |
+
raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
|
61 |
+
return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
|
62 |
+
try:
|
63 |
+
return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
|
64 |
+
except:
|
65 |
+
return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)
|
66 |
+
|
67 |
+
|
68 |
+
def get_cutoff(cutoff: float = None, scale: float = None) -> float:
|
69 |
+
if cutoff is not None:
|
70 |
+
return cutoff
|
71 |
+
|
72 |
+
if scale is not None and cutoff is None:
|
73 |
+
return 0.5 / scale
|
74 |
+
|
75 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
76 |
+
|
77 |
+
|
78 |
+
def get_scale(cutoff: float = None, scale: float = None) -> float:
|
79 |
+
if scale is not None:
|
80 |
+
return scale
|
81 |
+
|
82 |
+
if cutoff is not None and scale is None:
|
83 |
+
return 0.5 / cutoff
|
84 |
+
|
85 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
86 |
+
|
87 |
+
|
88 |
+
def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
89 |
+
assert len(k.shape) in (1,), 'Kernel size should be one of (1,).'
|
90 |
+
# assert len(k.shape) in (1, 2), 'Kernel size should be one of (1, 2).'
|
91 |
+
|
92 |
+
b, c, h, w = x.shape
|
93 |
+
ks = k.shape[-1]
|
94 |
+
k = k.view(1, 1, -1).repeat(c, 1, 1)
|
95 |
+
|
96 |
+
x = x.permute(0, 2, 1, 3)
|
97 |
+
x = x.reshape(b * h, c, w)
|
98 |
+
x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
|
99 |
+
x = F.conv1d(x, k, groups=c)
|
100 |
+
x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
|
101 |
+
x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
|
102 |
+
x = F.conv1d(x, k, groups=c)
|
103 |
+
x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
108 |
+
assert len(k.shape) in (2, 3), 'Kernel size should be one of (2, 3).'
|
109 |
+
|
110 |
+
x = F.pad(x, (
|
111 |
+
k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
|
112 |
+
k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
|
113 |
+
), mode='replicate')
|
114 |
+
|
115 |
+
b, c, _, _ = x.shape
|
116 |
+
if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
|
117 |
+
k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
|
118 |
+
x = F.conv2d(x, k, groups=c)
|
119 |
+
elif len(k.shape) == 3:
|
120 |
+
assert k.shape[0] == b, \
|
121 |
+
'The number of kernels should match the batch size.'
|
122 |
+
|
123 |
+
k = k.unsqueeze(1)
|
124 |
+
x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
@amp.autocast(False)
|
129 |
+
def filter_by_kernel(
|
130 |
+
x: torch.Tensor,
|
131 |
+
k: torch.Tensor,
|
132 |
+
is_batch: bool = False,
|
133 |
+
) -> torch.Tensor:
|
134 |
+
k_dim = len(k.shape)
|
135 |
+
if k_dim == 1 or k_dim == 2 and is_batch:
|
136 |
+
return filter_2d_by_kernel_1d(x, k)
|
137 |
+
elif k_dim == 2 or k_dim == 3 and is_batch:
|
138 |
+
return filter_2d_by_kernel_2d(x, k)
|
139 |
+
else:
|
140 |
+
raise ValueError('Kernel size should be one of (1, 2, 3).')
|
141 |
+
|
142 |
+
|
143 |
+
def gen_gauss_lowpass_filter_2d(
|
144 |
+
std: torch.Tensor,
|
145 |
+
window_size: int = None,
|
146 |
+
) -> torch.Tensor:
|
147 |
+
# Gaussian kernel size is odd in order to preserve the center.
|
148 |
+
if window_size is None:
|
149 |
+
window_size = (
|
150 |
+
2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)
|
151 |
+
|
152 |
+
y = torch.arange(
|
153 |
+
window_size, dtype=std.dtype, device=std.device
|
154 |
+
).view(-1, 1).repeat(1, window_size)
|
155 |
+
grid = torch.stack((y.t(), y), dim=-1)
|
156 |
+
grid -= 0.5 * (window_size - 1) # (W, W)
|
157 |
+
var = (std * std).unsqueeze(-1).unsqueeze(-1)
|
158 |
+
distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
|
159 |
+
k = torch.exp(-0.5 * distsq / var)
|
160 |
+
k /= k.sum(dim=(-2, -1), keepdim=True)
|
161 |
+
return k
|
162 |
+
|
163 |
+
|
164 |
+
def gaussian_lowpass(
|
165 |
+
x: torch.Tensor,
|
166 |
+
std: Union[float, Tuple[float], torch.Tensor] = None,
|
167 |
+
cutoff: Union[float, torch.Tensor] = None,
|
168 |
+
scale: Union[float, torch.Tensor] = None,
|
169 |
+
) -> torch.Tensor:
|
170 |
+
if std is None:
|
171 |
+
cutoff = get_cutoff(cutoff, scale)
|
172 |
+
std = 0.5 / (np.pi * cutoff)
|
173 |
+
if isinstance(std, (float, int)):
|
174 |
+
std = (std, std)
|
175 |
+
if isinstance(std, torch.Tensor):
|
176 |
+
"""Using nn.functional.conv2d with Gaussian kernels built in runtime is
|
177 |
+
80% faster than transforms.functional.gaussian_blur for individual
|
178 |
+
items.
|
179 |
+
|
180 |
+
(in GPU); However, in CPU, the result is exactly opposite. But you
|
181 |
+
won't gonna run this on CPU, right?
|
182 |
+
"""
|
183 |
+
if len(list(s for s in std.shape if s != 1)) >= 2:
|
184 |
+
raise NotImplementedError(
|
185 |
+
'Anisotropic Gaussian filter is not currently available.')
|
186 |
+
|
187 |
+
# k.shape == (B, W, W).
|
188 |
+
k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
|
189 |
+
if k.shape[0] == 1:
|
190 |
+
return filter_by_kernel(x, k[0], False)
|
191 |
+
else:
|
192 |
+
return filter_by_kernel(x, k, True)
|
193 |
+
else:
|
194 |
+
# Gaussian kernel size is odd in order to preserve the center.
|
195 |
+
window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
|
196 |
+
return TF.gaussian_blur(x, window_size, std)
|
197 |
+
|
198 |
+
|
199 |
+
def blend(
|
200 |
+
fg: Union[torch.Tensor, Image.Image],
|
201 |
+
bg: Union[torch.Tensor, Image.Image],
|
202 |
+
mask: Union[torch.Tensor, Image.Image],
|
203 |
+
std: float = 0.0,
|
204 |
+
) -> Image.Image:
|
205 |
+
if not isinstance(fg, torch.Tensor):
|
206 |
+
fg = T.ToTensor()(fg)
|
207 |
+
if not isinstance(bg, torch.Tensor):
|
208 |
+
bg = T.ToTensor()(bg)
|
209 |
+
if not isinstance(mask, torch.Tensor):
|
210 |
+
mask = (T.ToTensor()(mask) < 0.5).float()[:1]
|
211 |
+
if std > 0:
|
212 |
+
mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
|
213 |
+
return T.ToPILImage()(fg * mask + bg * (1 - mask))
|
214 |
+
|
215 |
+
|
216 |
+
def get_panorama_views(
|
217 |
+
panorama_height: int,
|
218 |
+
panorama_width: int,
|
219 |
+
window_size: int = 64,
|
220 |
+
) -> tuple[List[Tuple[int]], torch.Tensor]:
|
221 |
+
stride = window_size // 2
|
222 |
+
is_horizontal = panorama_width > panorama_height
|
223 |
+
num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
|
224 |
+
num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
|
225 |
+
total_num_blocks = num_blocks_height * num_blocks_width
|
226 |
+
|
227 |
+
half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
|
228 |
+
half_rev = half_fwd.flip(0)
|
229 |
+
if window_size % 2 == 1:
|
230 |
+
half_rev = half_rev[1:]
|
231 |
+
c = torch.cat((half_fwd, half_rev))
|
232 |
+
one = torch.ones_like(c)
|
233 |
+
f = c.clone()
|
234 |
+
f[:window_size // 2] = 1
|
235 |
+
b = c.clone()
|
236 |
+
b[-(window_size // 2):] = 1
|
237 |
+
|
238 |
+
h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
|
239 |
+
w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]
|
240 |
+
|
241 |
+
views = []
|
242 |
+
masks = torch.zeros(total_num_blocks, panorama_height, panorama_width) # (n, h, w)
|
243 |
+
for i in range(total_num_blocks):
|
244 |
+
hi, wi = i // num_blocks_width, i % num_blocks_width
|
245 |
+
h_start = hi * stride
|
246 |
+
h_end = min(h_start + window_size, panorama_height)
|
247 |
+
w_start = wi * stride
|
248 |
+
w_end = min(w_start + window_size, panorama_width)
|
249 |
+
views.append((h_start, h_end, w_start, w_end))
|
250 |
+
|
251 |
+
h_width = h_end - h_start
|
252 |
+
w_width = w_end - w_start
|
253 |
+
masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]
|
254 |
+
|
255 |
+
# Sum of the mask weights at each pixel `masks.sum(dim=1)` must be unity.
|
256 |
+
return views, masks[None] # (1, n, h, w)
|
257 |
+
|
258 |
+
|
259 |
+
def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> List[int]:
|
260 |
+
h, w = mask.shape[-2:]
|
261 |
+
device = mask.device
|
262 |
+
mask = mask.reshape(-1, h, w)
|
263 |
+
# assert mask.shape[0] == im.shape[0]
|
264 |
+
h_occupied = mask.sum(dim=-2) > 0
|
265 |
+
w_occupied = mask.sum(dim=-1) > 0
|
266 |
+
l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
|
267 |
+
r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
|
268 |
+
t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
|
269 |
+
b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
|
270 |
+
tb = (t + b + 1) // 2
|
271 |
+
lr = (l + r + 1) // 2
|
272 |
+
shifts = (tb - (h // 2), lr - (w // 2))
|
273 |
+
shifts = torch.cat(shifts, dim=1) # (p, 2)
|
274 |
+
if reverse:
|
275 |
+
shifts = shifts * -1
|
276 |
+
return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)
|
277 |
+
|
278 |
+
|
279 |
+
class Streamer:
|
280 |
+
def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
|
281 |
+
self.fn = fn
|
282 |
+
self.ema_alpha = ema_alpha
|
283 |
+
|
284 |
+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
285 |
+
self.future = self.executor.submit(fn)
|
286 |
+
self.image = None
|
287 |
+
|
288 |
+
self.prev_exec_time = 0
|
289 |
+
self.ema_exec_time = 0
|
290 |
+
|
291 |
+
@property
|
292 |
+
def throughput(self) -> float:
|
293 |
+
return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')
|
294 |
+
|
295 |
+
def timed_fn(self) -> Any:
|
296 |
+
start = time.time()
|
297 |
+
res = self.fn()
|
298 |
+
end = time.time()
|
299 |
+
self.prev_exec_time = end - start
|
300 |
+
self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
|
301 |
+
return res
|
302 |
+
|
303 |
+
def __call__(self) -> Any:
|
304 |
+
if self.future.done() or self.image is None:
|
305 |
+
# get the result (the new image) and start a new task
|
306 |
+
image = self.future.result()
|
307 |
+
self.future = self.executor.submit(self.timed_fn)
|
308 |
+
self.image = image
|
309 |
+
return image
|
310 |
+
else:
|
311 |
+
# if self.fn() is not ready yet, use the previous image
|
312 |
+
# NOTE: This assumes that we have access to a previously generated image here.
|
313 |
+
# If there's no previous image (i.e., this is the first invocation), you could fall
|
314 |
+
# back to some default image or handle it differently based on your requirements.
|
315 |
+
return self.image
|