Spaces: thejagstudio (Runtime error)
Commit 510ee71 • Parent(s): d9087f2 • committed by thejagstudio
Upload 61 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- __init__.py +0 -0
- app.py +513 -0
- app_settings.py +94 -0
- backend/__init__.py +0 -0
- backend/annotators/canny_control.py +15 -0
- backend/annotators/control_interface.py +12 -0
- backend/annotators/depth_control.py +15 -0
- backend/annotators/image_control_factory.py +31 -0
- backend/annotators/lineart_control.py +11 -0
- backend/annotators/mlsd_control.py +10 -0
- backend/annotators/normal_control.py +10 -0
- backend/annotators/pose_control.py +10 -0
- backend/annotators/shuffle_control.py +10 -0
- backend/annotators/softedge_control.py +10 -0
- backend/api/models/response.py +16 -0
- backend/api/web.py +103 -0
- backend/base64_image.py +21 -0
- backend/controlnet.py +90 -0
- backend/device.py +23 -0
- backend/image_saver.py +60 -0
- backend/lcm_text_to_image.py +386 -0
- backend/lora.py +136 -0
- backend/models/device.py +9 -0
- backend/models/gen_images.py +16 -0
- backend/models/lcmdiffusion_setting.py +64 -0
- backend/models/upscale.py +9 -0
- backend/openvino/custom_ov_model_vae_decoder.py +21 -0
- backend/openvino/pipelines.py +75 -0
- backend/pipelines/lcm.py +100 -0
- backend/pipelines/lcm_lora.py +82 -0
- backend/tiny_decoder.py +32 -0
- backend/upscale/aura_sr.py +834 -0
- backend/upscale/aura_sr_upscale.py +9 -0
- backend/upscale/edsr_upscale_onnx.py +37 -0
- backend/upscale/tiled_upscale.py +238 -0
- backend/upscale/upscaler.py +52 -0
- constants.py +20 -0
- context.py +77 -0
- frontend/cli_interactive.py +655 -0
- frontend/gui/app_window.py +612 -0
- frontend/gui/image_generator_worker.py +37 -0
- frontend/gui/ui.py +15 -0
- frontend/utils.py +83 -0
- frontend/webui/controlnet_ui.py +194 -0
- frontend/webui/css/style.css +22 -0
- frontend/webui/generation_settings_ui.py +157 -0
- frontend/webui/image_to_image_ui.py +120 -0
- frontend/webui/image_variations_ui.py +106 -0
- frontend/webui/lora_models_ui.py +185 -0
- frontend/webui/models_ui.py +85 -0
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,513 @@
import json
from argparse import ArgumentParser

import constants
from backend.controlnet import controlnet_settings_from_dict
from backend.models.gen_images import ImageFormat
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.upscale.tiled_upscale import generate_upscaled_image
from constants import APP_VERSION, DEVICE
from frontend.webui.image_variations_ui import generate_image_variations
from models.interface_types import InterfaceType
from paths import FastStableDiffusionPaths
from PIL import Image
from state import get_context, get_settings
from utils import show_system_info
from backend.device import get_device_name

parser = ArgumentParser(description=f"FAST SD CPU {constants.APP_VERSION}")
parser.add_argument(
    "-s",
    "--share",
    action="store_true",
    help="Create sharable link(Web UI)",
    required=False,
)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
    "-g",
    "--gui",
    action="store_true",
    help="Start desktop GUI",
)
group.add_argument(
    "-w",
    "--webui",
    action="store_true",
    help="Start Web UI",
)
group.add_argument(
    "-a",
    "--api",
    action="store_true",
    help="Start Web API server",
)
group.add_argument(
    "-r",
    "--realtime",
    action="store_true",
    help="Start realtime inference UI(experimental)",
)
group.add_argument(
    "-v",
    "--version",
    action="store_true",
    help="Version",
)

parser.add_argument(
    "-b",
    "--benchmark",
    action="store_true",
    help="Run inference benchmark on the selected device",
)
parser.add_argument(
    "--lcm_model_id",
    type=str,
    help="Model ID or path,Default stabilityai/sd-turbo",
    default="stabilityai/sd-turbo",
)
parser.add_argument(
    "--openvino_lcm_model_id",
    type=str,
    help="OpenVINO Model ID or path,Default rupeshs/sd-turbo-openvino",
    default="rupeshs/sd-turbo-openvino",
)
parser.add_argument(
    "--prompt",
    type=str,
    help="Describe the image you want to generate",
    default="",
)
parser.add_argument(
    "--negative_prompt",
    type=str,
    help="Describe what you want to exclude from the generation",
    default="",
)
parser.add_argument(
    "--image_height",
    type=int,
    help="Height of the image",
    default=512,
)
parser.add_argument(
    "--image_width",
    type=int,
    help="Width of the image",
    default=512,
)
parser.add_argument(
    "--inference_steps",
    type=int,
    help="Number of steps,default : 1",
    default=1,
)
parser.add_argument(
    "--guidance_scale",
    type=float,
    help="Guidance scale,default : 1.0",
    default=1.0,
)

parser.add_argument(
    "--number_of_images",
    type=int,
    help="Number of images to generate ,default : 1",
    default=1,
)
parser.add_argument(
    "--seed",
    type=int,
    help="Seed,default : -1 (disabled) ",
    default=-1,
)
parser.add_argument(
    "--use_openvino",
    action="store_true",
    help="Use OpenVINO model",
)

parser.add_argument(
    "--use_offline_model",
    action="store_true",
    help="Use offline model",
)
parser.add_argument(
    "--use_safety_checker",
    action="store_true",
    help="Use safety checker",
)
parser.add_argument(
    "--use_lcm_lora",
    action="store_true",
    help="Use LCM-LoRA",
)
parser.add_argument(
    "--base_model_id",
    type=str,
    help="LCM LoRA base model ID,Default Lykon/dreamshaper-8",
    default="Lykon/dreamshaper-8",
)
parser.add_argument(
    "--lcm_lora_id",
    type=str,
    help="LCM LoRA model ID,Default latent-consistency/lcm-lora-sdv1-5",
    default="latent-consistency/lcm-lora-sdv1-5",
)
parser.add_argument(
    "-i",
    "--interactive",
    action="store_true",
    help="Interactive CLI mode",
)
parser.add_argument(
    "-t",
    "--use_tiny_auto_encoder",
    action="store_true",
    help="Use tiny auto encoder for SD (TAESD)",
)
parser.add_argument(
    "-f",
    "--file",
    type=str,
    help="Input image for img2img mode",
    default="",
)
parser.add_argument(
    "--img2img",
    action="store_true",
    help="img2img mode; requires input file via -f argument",
)
parser.add_argument(
    "--batch_count",
    type=int,
    help="Number of sequential generations",
    default=1,
)
parser.add_argument(
    "--strength",
    type=float,
    help="Denoising strength for img2img and Image variations",
    default=0.3,
)
parser.add_argument(
    "--sdupscale",
    action="store_true",
    help="Tiled SD upscale,works only for the resolution 512x512,(2x upscale)",
)
parser.add_argument(
    "--upscale",
    action="store_true",
    help="EDSR SD upscale ",
)
parser.add_argument(
    "--custom_settings",
    type=str,
    help="JSON file containing custom generation settings",
    default=None,
)
parser.add_argument(
    "--usejpeg",
    action="store_true",
    help="Images will be saved as JPEG format",
)
parser.add_argument(
    "--noimagesave",
    action="store_true",
    help="Disable image saving",
)
parser.add_argument(
    "--lora",
    type=str,
    help="LoRA model full path e.g D:\lora_models\CuteCartoon15V-LiberteRedmodModel-Cartoon-CuteCartoonAF.safetensors",
    default=None,
)
parser.add_argument(
    "--lora_weight",
    type=float,
    help="LoRA adapter weight [0 to 1.0]",
    default=0.5,
)

args = parser.parse_args()

if args.version:
    print(APP_VERSION)
    exit()

# parser.print_help()
show_system_info()
print(f"Using device : {constants.DEVICE}")

if args.webui:
    app_settings = get_settings()
else:
    app_settings = get_settings()

print(f"Found {len(app_settings.lcm_models)} LCM models in config/lcm-models.txt")
print(
    f"Found {len(app_settings.stable_diffsuion_models)} stable diffusion models in config/stable-diffusion-models.txt"
)
print(
    f"Found {len(app_settings.lcm_lora_models)} LCM-LoRA models in config/lcm-lora-models.txt"
)
print(
    f"Found {len(app_settings.openvino_lcm_models)} OpenVINO LCM models in config/openvino-lcm-models.txt"
)

if args.noimagesave:
    app_settings.settings.generated_images.save_image = False
else:
    app_settings.settings.generated_images.save_image = True

if not args.realtime:
    # To minimize realtime mode dependencies
    from backend.upscale.upscaler import upscale_image
    from frontend.cli_interactive import interactive_mode

if args.gui:
    from frontend.gui.ui import start_gui

    print("Starting desktop GUI mode(Qt)")
    start_gui(
        [],
        app_settings,
    )
elif args.webui:
    from frontend.webui.ui import start_webui

    print("Starting web UI mode")
    start_webui(
        args.share,
    )
elif args.realtime:
    from frontend.webui.realtime_ui import start_realtime_text_to_image

    print("Starting realtime text to image(EXPERIMENTAL)")
    start_realtime_text_to_image(args.share)
elif args.api:
    from backend.api.web import start_web_server

    start_web_server()

else:
    context = get_context(InterfaceType.CLI)
    config = app_settings.settings

    if args.use_openvino:
        config.lcm_diffusion_setting.openvino_lcm_model_id = args.openvino_lcm_model_id
    else:
        config.lcm_diffusion_setting.lcm_model_id = args.lcm_model_id

    config.lcm_diffusion_setting.prompt = args.prompt
    config.lcm_diffusion_setting.negative_prompt = args.negative_prompt
    config.lcm_diffusion_setting.image_height = args.image_height
    config.lcm_diffusion_setting.image_width = args.image_width
    config.lcm_diffusion_setting.guidance_scale = args.guidance_scale
    config.lcm_diffusion_setting.number_of_images = args.number_of_images
    config.lcm_diffusion_setting.inference_steps = args.inference_steps
    config.lcm_diffusion_setting.strength = args.strength
    config.lcm_diffusion_setting.seed = args.seed
    config.lcm_diffusion_setting.use_openvino = args.use_openvino
    config.lcm_diffusion_setting.use_tiny_auto_encoder = args.use_tiny_auto_encoder
    config.lcm_diffusion_setting.use_lcm_lora = args.use_lcm_lora
    config.lcm_diffusion_setting.lcm_lora.base_model_id = args.base_model_id
    config.lcm_diffusion_setting.lcm_lora.lcm_lora_id = args.lcm_lora_id
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
    config.lcm_diffusion_setting.lora.enabled = False
    config.lcm_diffusion_setting.lora.path = args.lora
    config.lcm_diffusion_setting.lora.weight = args.lora_weight
    config.lcm_diffusion_setting.lora.fuse = True
    if config.lcm_diffusion_setting.lora.path:
        config.lcm_diffusion_setting.lora.enabled = True
    if args.usejpeg:
        config.generated_images.format = ImageFormat.JPEG.value.upper()
    if args.seed > -1:
        config.lcm_diffusion_setting.use_seed = True
    else:
        config.lcm_diffusion_setting.use_seed = False
    config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
    config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker

    # Read custom settings from JSON file
    custom_settings = {}
    if args.custom_settings:
        with open(args.custom_settings) as f:
            custom_settings = json.load(f)

    # Basic ControlNet settings; if ControlNet is enabled, an image is
    # required even in txt2img mode
    config.lcm_diffusion_setting.controlnet = None
    controlnet_settings_from_dict(
        config.lcm_diffusion_setting,
        custom_settings,
    )

    # Interactive mode
    if args.interactive:
        # wrapper(interactive_mode, config, context)
        config.lcm_diffusion_setting.lora.fuse = False
        interactive_mode(config, context)

    # Start of non-interactive CLI image generation
    if args.img2img and args.file != "":
        config.lcm_diffusion_setting.init_image = Image.open(args.file)
        config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    elif args.img2img and args.file == "":
        print("Error : You need to specify a file in img2img mode")
        exit()
    elif args.upscale and args.file == "" and args.custom_settings == None:
        print("Error : You need to specify a file in SD upscale mode")
        exit()
    elif (
        args.prompt == ""
        and args.file == ""
        and args.custom_settings == None
        and not args.benchmark
    ):
        print("Error : You need to provide a prompt")
        exit()

    if args.upscale:
        # image = Image.open(args.file)
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            args.file,
            2,
            config.generated_images.format,
        )
        result = upscale_image(
            context,
            args.file,
            output_path,
            2,
        )
    # Perform Tiled SD upscale (EXPERIMENTAL)
    elif args.sdupscale:
        if args.use_openvino:
            config.lcm_diffusion_setting.strength = 0.3
        upscale_settings = None
        if custom_settings != {}:
            upscale_settings = custom_settings
        filepath = args.file
        output_format = config.generated_images.format
        if upscale_settings:
            filepath = upscale_settings["source_file"]
            output_format = upscale_settings["output_format"].upper()
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            filepath,
            2,
            output_format,
        )

        generate_upscaled_image(
            config,
            filepath,
            config.lcm_diffusion_setting.strength,
            upscale_settings=upscale_settings,
            context=context,
            tile_overlap=32 if config.lcm_diffusion_setting.use_openvino else 16,
            output_path=output_path,
            image_format=output_format,
        )
        exit()
    # If img2img argument is set and prompt is empty, use image variations mode
    elif args.img2img and args.prompt == "":
        for i in range(0, args.batch_count):
            generate_image_variations(
                config.lcm_diffusion_setting.init_image, args.strength
            )
    else:

        if args.benchmark:
            print("Initializing benchmark...")
            bench_lcm_setting = config.lcm_diffusion_setting
            bench_lcm_setting.prompt = "a cat"
            bench_lcm_setting.use_tiny_auto_encoder = False
            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            latencies = []

            print("Starting benchmark please wait...")
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency = sum(latencies) / 3

            bench_lcm_setting.use_tiny_auto_encoder = True

            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
            latencies = []
            for _ in range(3):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
                latencies.append(context.latency)

            avg_latency_taesd = sum(latencies) / 3

            benchmark_name = ""

            if config.lcm_diffusion_setting.use_openvino:
                benchmark_name = "OpenVINO"
            else:
                benchmark_name = "PyTorch"

            bench_model_id = ""
            if bench_lcm_setting.use_openvino:
                bench_model_id = bench_lcm_setting.openvino_lcm_model_id
            elif bench_lcm_setting.use_lcm_lora:
                bench_model_id = bench_lcm_setting.lcm_lora.base_model_id
            else:
                bench_model_id = bench_lcm_setting.lcm_model_id

            benchmark_result = [
                ["Device", f"{DEVICE.upper()},{get_device_name()}"],
                ["Stable Diffusion Model", bench_model_id],
                [
                    "Image Size ",
                    f"{bench_lcm_setting.image_width}x{bench_lcm_setting.image_height}",
                ],
                [
                    "Inference Steps",
                    f"{bench_lcm_setting.inference_steps}",
                ],
                [
                    "Benchmark Passes",
                    3,
                ],
                [
                    "Average Latency",
                    f"{round(avg_latency,3)} sec",
                ],
                [
                    "Average Latency(TAESD* enabled)",
                    f"{round(avg_latency_taesd,3)} sec",
                ],
            ]
            print()
            print(
                f" FastSD Benchmark - {benchmark_name:8} "
            )
            print(f"-" * 80)
            for benchmark in benchmark_result:
                print(f"{benchmark[0]:35} - {benchmark[1]}")
            print(f"-" * 80)
            print("*TAESD - Tiny AutoEncoder for Stable Diffusion")

        else:
            for i in range(0, args.batch_count):
                context.generate_text_to_image(
                    settings=config,
                    device=DEVICE,
                )
app_settings.py
ADDED
@@ -0,0 +1,94 @@
import yaml
from os import path, makedirs
from models.settings import Settings
from paths import FastStableDiffusionPaths
from utils import get_models_from_text_file
from constants import (
    OPENVINO_LCM_MODELS_FILE,
    LCM_LORA_MODELS_FILE,
    SD_MODELS_FILE,
    LCM_MODELS_FILE,
)
from copy import deepcopy


class AppSettings:
    def __init__(self):
        self.config_path = FastStableDiffusionPaths().get_app_settings_path()
        self._stable_diffsuion_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(SD_MODELS_FILE)
        )
        self._lcm_lora_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_LORA_MODELS_FILE)
        )
        self._openvino_lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(OPENVINO_LCM_MODELS_FILE)
        )
        self._lcm_models = get_models_from_text_file(
            FastStableDiffusionPaths().get_models_config_path(LCM_MODELS_FILE)
        )
        self._config = None

    @property
    def settings(self):
        return self._config

    @property
    def stable_diffsuion_models(self):
        return self._stable_diffsuion_models

    @property
    def openvino_lcm_models(self):
        return self._openvino_lcm_models

    @property
    def lcm_models(self):
        return self._lcm_models

    @property
    def lcm_lora_models(self):
        return self._lcm_lora_models

    def load(self, skip_file=False):
        if skip_file:
            print("Skipping config file")
            settings_dict = self._load_default()
            self._config = Settings.model_validate(settings_dict)
        else:
            if not path.exists(self.config_path):
                base_dir = path.dirname(self.config_path)
                if not path.exists(base_dir):
                    makedirs(base_dir)
                try:
                    print("Settings not found creating default settings")
                    with open(self.config_path, "w") as file:
                        yaml.dump(
                            self._load_default(),
                            file,
                        )
                except Exception as ex:
                    print(f"Error in creating settings : {ex}")
                    exit()
            try:
                with open(self.config_path) as file:
                    settings_dict = yaml.safe_load(file)
                    self._config = Settings.model_validate(settings_dict)
            except Exception as ex:
                print(f"Error in loading settings : {ex}")

    def save(self):
        try:
            with open(self.config_path, "w") as file:
                tmp_cfg = deepcopy(self._config)
                tmp_cfg.lcm_diffusion_setting.init_image = None
                configurations = tmp_cfg.model_dump(
                    exclude=["init_image"],
                )
                if configurations:
                    yaml.dump(configurations, file)
        except Exception as ex:
            print(f"Error in saving settings : {ex}")

    def _load_default(self) -> dict:
        default_config = Settings()
        return default_config.model_dump()
backend/__init__.py
ADDED
File without changes
backend/annotators/canny_control.py
ADDED
@@ -0,0 +1,15 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from cv2 import Canny
from PIL import Image


class CannyControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        low_threshold = 100
        high_threshold = 200
        image = np.array(image)
        image = Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        return Image.fromarray(image)
backend/annotators/control_interface.py
ADDED
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod

from PIL import Image


class ControlInterface(ABC):
    @abstractmethod
    def get_control_image(
        self,
        image: Image,
    ) -> Image:
        pass
backend/annotators/depth_control.py
ADDED
@@ -0,0 +1,15 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from PIL import Image
from transformers import pipeline


class DepthControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        depth_estimator = pipeline("depth-estimation")
        image = depth_estimator(image)["depth"]
        image = np.array(image)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)
        return image
backend/annotators/image_control_factory.py
ADDED
@@ -0,0 +1,31 @@
from backend.annotators.canny_control import CannyControl
from backend.annotators.depth_control import DepthControl
from backend.annotators.lineart_control import LineArtControl
from backend.annotators.mlsd_control import MlsdControl
from backend.annotators.normal_control import NormalControl
from backend.annotators.pose_control import PoseControl
from backend.annotators.shuffle_control import ShuffleControl
from backend.annotators.softedge_control import SoftEdgeControl


class ImageControlFactory:
    def create_control(self, controlnet_type: str):
        if controlnet_type == "Canny":
            return CannyControl()
        elif controlnet_type == "Pose":
            return PoseControl()
        elif controlnet_type == "MLSD":
            return MlsdControl()
        elif controlnet_type == "Depth":
            return DepthControl()
        elif controlnet_type == "LineArt":
            return LineArtControl()
        elif controlnet_type == "Shuffle":
            return ShuffleControl()
        elif controlnet_type == "NormalBAE":
            return NormalControl()
        elif controlnet_type == "SoftEdge":
            return SoftEdgeControl()
        else:
            print("Error: Control type not implemented!")
            raise Exception("Error: Control type not implemented!")
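A minimal usage sketch for the factory above (the file name input.png is a placeholder): create_control maps a type name such as "Canny" to its annotator, and get_control_image produces the control image.

from PIL import Image

from backend.annotators.image_control_factory import ImageControlFactory

# Pick an annotator by name; "Canny" is one of the branches handled above.
control = ImageControlFactory().create_control("Canny")
control_image = control.get_control_image(Image.open("input.png"))  # placeholder path
control_image.save("canny-control.png")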
backend/annotators/lineart_control.py
ADDED
@@ -0,0 +1,11 @@
import numpy as np
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import LineartDetector
from PIL import Image


class LineArtControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/annotators/mlsd_control.py
ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import MLSDdetector
from PIL import Image


class MlsdControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
        image = mlsd(image)
        return image
backend/annotators/normal_control.py
ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import NormalBaeDetector
from PIL import Image


class NormalControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/annotators/pose_control.py
ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import OpenposeDetector
from PIL import Image


class PoseControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = openpose(image)
        return image
backend/annotators/shuffle_control.py
ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import ContentShuffleDetector
from PIL import Image


class ShuffleControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        shuffle_processor = ContentShuffleDetector()
        image = shuffle_processor(image)
        return image
backend/annotators/softedge_control.py
ADDED
@@ -0,0 +1,10 @@
from backend.annotators.control_interface import ControlInterface
from controlnet_aux import PidiNetDetector
from PIL import Image


class SoftEdgeControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        control_image = processor(image)
        return control_image
backend/api/models/response.py
ADDED
@@ -0,0 +1,16 @@
from typing import List

from pydantic import BaseModel


class StableDiffusionResponse(BaseModel):
    """
    Stable diffusion response model

    Attributes:
        images (List[str]): List of JPEG image as base64 encoded
        latency (float): Latency in seconds
    """

    images: List[str]
    latency: float
backend/api/web.py
ADDED
@@ -0,0 +1,103 @@
import platform

import uvicorn
from backend.api.models.response import StableDiffusionResponse
from backend.models.device import DeviceInfo
from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str
from backend.device import get_device_name
from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting
from constants import APP_VERSION, DEVICE
from context import Context
from fastapi import FastAPI
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app = FastAPI(
    title="FastSD CPU",
    description="Fast stable diffusion on CPU",
    version=APP_VERSION,
    license_info={
        "name": "MIT",
        "identifier": "MIT",
    },
    docs_url="/api/docs",
    redoc_url="/api/redoc",
    openapi_url="/api/openapi.json",
)
print(app_settings.settings.lcm_diffusion_setting)

context = Context(InterfaceType.API_SERVER)


@app.get("/api/")
async def root():
    return {"message": "Welcome to FastSD CPU API"}


@app.get(
    "/api/info",
    description="Get system information",
    summary="Get system information",
)
async def info():
    device_info = DeviceInfo(
        device_type=DEVICE,
        device_name=get_device_name(),
        os=platform.system(),
        platform=platform.platform(),
        processor=platform.processor(),
    )
    return device_info.model_dump()


@app.get(
    "/api/config",
    description="Get current configuration",
    summary="Get configurations",
)
async def config():
    return app_settings.settings


@app.get(
    "/api/models",
    description="Get available models",
    summary="Get available models",
)
async def models():
    return {
        "lcm_lora_models": app_settings.lcm_lora_models,
        "stable_diffusion": app_settings.stable_diffsuion_models,
        "openvino_models": app_settings.openvino_lcm_models,
        "lcm_models": app_settings.lcm_models,
    }


@app.post(
    "/api/generate",
    description="Generate image(Text to image,Image to Image)",
    summary="Generate image(Text to image,Image to Image)",
)
async def generate(diffusion_config: LCMDiffusionSetting) -> StableDiffusionResponse:
    app_settings.settings.lcm_diffusion_setting = diffusion_config
    if diffusion_config.diffusion_task == DiffusionTask.image_to_image:
        app_settings.settings.lcm_diffusion_setting.init_image = base64_image_to_pil(
            diffusion_config.init_image
        )

    images = context.generate_text_to_image(app_settings.settings)

    images_base64 = [pil_image_to_base64_str(img) for img in images]
    return StableDiffusionResponse(
        latency=round(context.latency, 2),
        images=images_base64,
    )


def start_web_server():
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
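A sketch of a client for the /api/generate endpoint above, assuming the server was started with start_web_server() on its default host and port; only the prompt field of LCMDiffusionSetting is set here, the rest keep their defaults.

import base64

import requests

payload = {"prompt": "a cat"}  # remaining LCMDiffusionSetting fields use defaults
response = requests.post("http://localhost:8000/api/generate", json=payload)
response.raise_for_status()
body = response.json()  # StableDiffusionResponse: {"images": [...], "latency": ...}
with open("out.jpg", "wb") as f:
    f.write(base64.b64decode(body["images"][0]))  # images are base64-encoded JPEGs
print(f"Latency: {body['latency']} sec")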
backend/base64_image.py
ADDED
@@ -0,0 +1,21 @@
from io import BytesIO
from base64 import b64encode, b64decode
from PIL import Image


def pil_image_to_base64_str(
    image: Image,
    format: str = "JPEG",
) -> str:
    buffer = BytesIO()
    image.save(buffer, format=format)
    buffer.seek(0)
    img_base64 = b64encode(buffer.getvalue()).decode("utf-8")
    return img_base64


def base64_image_to_pil(base64_str) -> Image:
    image_data = b64decode(base64_str)
    image_buffer = BytesIO(image_data)
    image = Image.open(image_buffer)
    return image
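A quick round-trip sketch for the two helpers above, using a synthetic image:

from PIL import Image

from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str

original = Image.new("RGB", (64, 64), "red")
encoded = pil_image_to_base64_str(original)  # JPEG bytes as a base64 string
decoded = base64_image_to_pil(encoded)       # back to a PIL image
assert decoded.size == original.size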
backend/controlnet.py
ADDED
@@ -0,0 +1,90 @@
import logging
from PIL import Image
from diffusers import ControlNetModel
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    ControlNetSetting,
)


# Prepares ControlNet adapters for use with FastSD CPU
#
# This function loads the ControlNet adapters defined by the
# _lcm_diffusion_setting.controlnet_ object and returns a dictionary
# with the pipeline arguments required to use the loaded adapters
def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    logging.info("Loading ControlNet adapter")
    controlnet_adapter = ControlNetModel.from_single_file(
        lcm_diffusion_setting.controlnet.adapter_path,
        local_files_only=True,
        use_safetensors=True,
    )
    controlnet_args["controlnet"] = controlnet_adapter
    return controlnet_args


# Updates the ControlNet pipeline arguments to use for image generation
#
# This function uses the contents of the _lcm_diffusion_setting.controlnet_
# object to generate a dictionary with the corresponding pipeline arguments
# to be used for image generation; in particular, it sets the ControlNet control
# image and conditioning scale
def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    controlnet_args["controlnet_conditioning_scale"] = (
        lcm_diffusion_setting.controlnet.conditioning_scale
    )
    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
        controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
    elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
        controlnet_args["control_image"] = (
            lcm_diffusion_setting.controlnet._control_image
        )
    return controlnet_args


# Helper function to adjust ControlNet settings from a dictionary
def controlnet_settings_from_dict(
    lcm_diffusion_setting,
    dictionary,
) -> None:
    if lcm_diffusion_setting is None or dictionary is None:
        logging.error("Invalid arguments!")
        return
    if (
        "controlnet" not in dictionary
        or dictionary["controlnet"] is None
        or len(dictionary["controlnet"]) == 0
    ):
        logging.warning("ControlNet settings not found, ControlNet will be disabled")
        lcm_diffusion_setting.controlnet = None
        return

    controlnet = ControlNetSetting()
    controlnet.enabled = dictionary["controlnet"][0]["enabled"]
    controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
    controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
    controlnet._control_image = None
    image_path = dictionary["controlnet"][0]["control_image"]
    if controlnet.enabled:
        try:
            controlnet._control_image = Image.open(image_path)
        except (AttributeError, FileNotFoundError) as err:
            print(err)
        if controlnet._control_image is None:
            logging.error("Wrong ControlNet control image! Disabling ControlNet")
            controlnet.enabled = False
    lcm_diffusion_setting.controlnet = controlnet
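For reference, the dictionary shape consumed by controlnet_settings_from_dict above, e.g. the contents of the JSON file passed to app.py via --custom_settings. The two paths are placeholders; the keys mirror exactly what the function reads.

custom_settings = {
    "controlnet": [
        {
            "enabled": True,
            "conditioning_scale": 0.5,
            "adapter_path": "path/to/controlnet-adapter.safetensors",  # placeholder
            "control_image": "path/to/control-image.png",  # placeholder
        }
    ]
}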
backend/device.py
ADDED
@@ -0,0 +1,23 @@
import platform
from constants import DEVICE
import torch
import openvino as ov

core = ov.Core()


def is_openvino_device() -> bool:
    if DEVICE.lower() == "cpu" or DEVICE.lower()[0] == "g" or DEVICE.lower()[0] == "n":
        return True
    else:
        return False


def get_device_name() -> str:
    if DEVICE == "cuda" or DEVICE == "mps":
        default_gpu_index = torch.cuda.current_device()
        return torch.cuda.get_device_name(default_gpu_index)
    elif platform.system().lower() == "darwin":
        return platform.processor()
    elif is_openvino_device():
        return core.get_property(DEVICE.upper(), "FULL_DEVICE_NAME")
backend/image_saver.py
ADDED
@@ -0,0 +1,60 @@
import json
from os import path, mkdir
from typing import Any
from uuid import uuid4
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
from utils import get_image_file_extension


def get_exclude_keys():
    exclude_keys = {
        "init_image": True,
        "generated_images": True,
        "lora": {
            "models_dir": True,
            "path": True,
        },
        "dirs": True,
        "controlnet": {
            "adapter_path": True,
        },
    }
    return exclude_keys


class ImageSaver:
    @staticmethod
    def save_images(
        output_path: str,
        images: Any,
        folder_name: str = "",
        format: str = "PNG",
        lcm_diffusion_setting: LCMDiffusionSetting = None,
    ) -> None:
        gen_id = uuid4()

        for index, image in enumerate(images):
            if not path.exists(output_path):
                mkdir(output_path)

            if folder_name:
                out_path = path.join(
                    output_path,
                    folder_name,
                )
            else:
                out_path = output_path

            if not path.exists(out_path):
                mkdir(out_path)
            image_extension = get_image_file_extension(format)
            image.save(path.join(out_path, f"{gen_id}-{index+1}{image_extension}"))
        if lcm_diffusion_setting:
            with open(path.join(out_path, f"{gen_id}.json"), "w") as json_file:
                json.dump(
                    lcm_diffusion_setting.model_dump(
                        exclude=get_exclude_keys(),
                    ),
                    json_file,
                    indent=4,
                )
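A usage sketch for ImageSaver.save_images above, with a stand-in image in place of real pipeline output; the results directory name is illustrative.

from PIL import Image

from backend.image_saver import ImageSaver

images = [Image.new("RGB", (512, 512))]  # stand-in for pipeline output
ImageSaver.save_images(
    "results",  # created if missing
    images=images,
    folder_name="txt2img",
    format="PNG",  # mapped to a file extension via get_image_file_extension
)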
backend/lcm_text_to_image.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from math import ceil
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import logging
|
8 |
+
from backend.device import is_openvino_device
|
9 |
+
from backend.lora import load_lora_weight
|
10 |
+
from backend.controlnet import (
|
11 |
+
load_controlnet_adapters,
|
12 |
+
update_controlnet_arguments,
|
13 |
+
)
|
14 |
+
from backend.models.lcmdiffusion_setting import (
|
15 |
+
DiffusionTask,
|
16 |
+
LCMDiffusionSetting,
|
17 |
+
LCMLora,
|
18 |
+
)
|
19 |
+
from backend.openvino.pipelines import (
|
20 |
+
get_ov_image_to_image_pipeline,
|
21 |
+
get_ov_text_to_image_pipeline,
|
22 |
+
ov_load_taesd,
|
23 |
+
)
|
24 |
+
from backend.pipelines.lcm import (
|
25 |
+
get_image_to_image_pipeline,
|
26 |
+
get_lcm_model_pipeline,
|
27 |
+
load_taesd,
|
28 |
+
)
|
29 |
+
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
|
30 |
+
from constants import DEVICE
|
31 |
+
from diffusers import LCMScheduler
|
32 |
+
from image_ops import resize_pil_image
|
33 |
+
|
34 |
+
|
35 |
+
class LCMTextToImage:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
device: str = "cpu",
|
39 |
+
) -> None:
|
40 |
+
self.pipeline = None
|
41 |
+
self.use_openvino = False
|
42 |
+
self.device = ""
|
43 |
+
self.previous_model_id = None
|
44 |
+
self.previous_use_tae_sd = False
|
45 |
+
self.previous_use_lcm_lora = False
|
46 |
+
self.previous_ov_model_id = ""
|
47 |
+
self.previous_safety_checker = False
|
48 |
+
self.previous_use_openvino = False
|
49 |
+
self.img_to_img_pipeline = None
|
50 |
+
self.is_openvino_init = False
|
51 |
+
self.previous_lora = None
|
52 |
+
self.task_type = DiffusionTask.text_to_image
|
53 |
+
self.torch_data_type = (
|
54 |
+
torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
|
55 |
+
)
|
56 |
+
print(f"Torch datatype : {self.torch_data_type}")
|
57 |
+
|
58 |
+
def _pipeline_to_device(self):
|
59 |
+
print(f"Pipeline device : {DEVICE}")
|
60 |
+
print(f"Pipeline dtype : {self.torch_data_type}")
|
61 |
+
self.pipeline.to(
|
62 |
+
torch_device=DEVICE,
|
63 |
+
torch_dtype=self.torch_data_type,
|
64 |
+
)
|
65 |
+
|
66 |
+
def _add_freeu(self):
|
67 |
+
pipeline_class = self.pipeline.__class__.__name__
|
68 |
+
if isinstance(self.pipeline.scheduler, LCMScheduler):
|
69 |
+
if pipeline_class == "StableDiffusionPipeline":
|
70 |
+
print("Add FreeU - SD")
|
71 |
+
self.pipeline.enable_freeu(
|
72 |
+
s1=0.9,
|
73 |
+
s2=0.2,
|
74 |
+
b1=1.2,
|
75 |
+
b2=1.4,
|
76 |
+
)
|
77 |
+
elif pipeline_class == "StableDiffusionXLPipeline":
|
78 |
+
print("Add FreeU - SDXL")
|
79 |
+
self.pipeline.enable_freeu(
|
80 |
+
s1=0.6,
|
81 |
+
s2=0.4,
|
82 |
+
b1=1.1,
|
83 |
+
b2=1.2,
|
84 |
+
)
|
85 |
+
|
86 |
+
def _enable_vae_tiling(self):
|
87 |
+
self.pipeline.vae.enable_tiling()
|
88 |
+
|
89 |
+
def _update_lcm_scheduler_params(self):
|
90 |
+
if isinstance(self.pipeline.scheduler, LCMScheduler):
|
91 |
+
self.pipeline.scheduler = LCMScheduler.from_config(
|
92 |
+
self.pipeline.scheduler.config,
|
93 |
+
beta_start=0.001,
|
94 |
+
beta_end=0.01,
|
95 |
+
)
|
96 |
+
|
97 |
+
def init(
|
98 |
+
self,
|
99 |
+
device: str = "cpu",
|
100 |
+
lcm_diffusion_setting: LCMDiffusionSetting = LCMDiffusionSetting(),
|
101 |
+
) -> None:
|
102 |
+
self.device = device
|
103 |
+
self.use_openvino = lcm_diffusion_setting.use_openvino
|
104 |
+
model_id = lcm_diffusion_setting.lcm_model_id
|
105 |
+
use_local_model = lcm_diffusion_setting.use_offline_model
|
106 |
+
use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
|
107 |
+
use_lora = lcm_diffusion_setting.use_lcm_lora
|
108 |
+
lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
|
109 |
+
ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id
|
110 |
+
|
111 |
+
if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
|
112 |
+
lcm_diffusion_setting.init_image = resize_pil_image(
|
113 |
+
lcm_diffusion_setting.init_image,
|
114 |
+
lcm_diffusion_setting.image_width,
|
115 |
+
lcm_diffusion_setting.image_height,
|
116 |
+
)
|
117 |
+
|
118 |
+
if (
|
119 |
+
self.pipeline is None
|
120 |
+
or self.previous_model_id != model_id
|
121 |
+
or self.previous_use_tae_sd != use_tiny_auto_encoder
|
122 |
+
or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
|
123 |
+
or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
|
124 |
+
or self.previous_use_lcm_lora != use_lora
|
125 |
+
or self.previous_ov_model_id != ov_model_id
|
126 |
+
or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
|
127 |
+
or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
|
128 |
+
or (
|
129 |
+
self.use_openvino
|
130 |
+
and (
|
131 |
+
self.previous_task_type != lcm_diffusion_setting.diffusion_task
|
132 |
+
or self.previous_lora != lcm_diffusion_setting.lora
|
133 |
+
)
|
134 |
+
)
|
135 |
+
or lcm_diffusion_setting.rebuild_pipeline
|
136 |
+
):
|
137 |
+
if self.use_openvino and is_openvino_device():
|
138 |
+
if self.pipeline:
|
139 |
+
del self.pipeline
|
140 |
+
self.pipeline = None
|
141 |
+
gc.collect()
|
142 |
+
self.is_openvino_init = True
|
143 |
+
if (
|
144 |
+
lcm_diffusion_setting.diffusion_task
|
145 |
+
== DiffusionTask.text_to_image.value
|
146 |
+
):
|
147 |
+
print(f"***** Init Text to image (OpenVINO) - {ov_model_id} *****")
|
148 |
+
self.pipeline = get_ov_text_to_image_pipeline(
|
149 |
+
ov_model_id,
|
150 |
+
use_local_model,
|
151 |
+
)
|
152 |
+
elif (
|
153 |
+
lcm_diffusion_setting.diffusion_task
|
154 |
+
== DiffusionTask.image_to_image.value
|
155 |
+
):
|
156 |
+
print(f"***** Image to image (OpenVINO) - {ov_model_id} *****")
|
157 |
+
self.pipeline = get_ov_image_to_image_pipeline(
|
158 |
+
ov_model_id,
|
159 |
+
use_local_model,
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
if self.pipeline:
|
163 |
+
del self.pipeline
|
164 |
+
self.pipeline = None
|
165 |
+
if self.img_to_img_pipeline:
|
166 |
+
del self.img_to_img_pipeline
|
167 |
+
self.img_to_img_pipeline = None
|
168 |
+
|
169 |
+
controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
|
170 |
+
if use_lora:
|
171 |
+
print(
|
172 |
+
f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
|
173 |
+
)
|
174 |
+
self.pipeline = get_lcm_lora_pipeline(
|
175 |
+
lcm_lora.base_model_id,
|
176 |
+
lcm_lora.lcm_lora_id,
|
177 |
+
use_local_model,
|
178 |
+
torch_data_type=self.torch_data_type,
|
179 |
+
pipeline_args=controlnet_args,
|
180 |
+
)
|
181 |
+
|
182 |
+
else:
|
183 |
+
print(f"***** Init LCM Model pipeline - {model_id} *****")
|
184 |
+
self.pipeline = get_lcm_model_pipeline(
|
185 |
+
model_id,
|
186 |
+
use_local_model,
|
187 |
+
controlnet_args,
|
188 |
+
)
|
189 |
+
|
190 |
+
self.img_to_img_pipeline = get_image_to_image_pipeline(self.pipeline)
|
191 |
+
|
192 |
+
if use_tiny_auto_encoder:
|
193 |
+
if self.use_openvino and is_openvino_device():
|
194 |
+
print("Using Tiny Auto Encoder (OpenVINO)")
|
195 |
+
ov_load_taesd(
|
196 |
+
self.pipeline,
|
197 |
+
use_local_model,
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
print("Using Tiny Auto Encoder")
|
201 |
+
load_taesd(
|
202 |
+
self.pipeline,
|
203 |
+
use_local_model,
|
204 |
+
self.torch_data_type,
|
205 |
+
)
|
206 |
+
load_taesd(
|
207 |
+
self.img_to_img_pipeline,
|
208 |
+
use_local_model,
|
209 |
+
self.torch_data_type,
|
210 |
+
)
|
211 |
+
|
212 |
+
if not self.use_openvino and not is_openvino_device():
|
213 |
+
self._pipeline_to_device()
|
214 |
+
|
215 |
+
if (
|
216 |
+
lcm_diffusion_setting.diffusion_task
|
217 |
+
                == DiffusionTask.image_to_image.value
                and lcm_diffusion_setting.use_openvino
            ):
                self.pipeline.scheduler = LCMScheduler.from_config(
                    self.pipeline.scheduler.config,
                )
            else:
                self._update_lcm_scheduler_params()

            if use_lora:
                self._add_freeu()

            self.previous_model_id = model_id
            self.previous_ov_model_id = ov_model_id
            self.previous_use_tae_sd = use_tiny_auto_encoder
            self.previous_lcm_lora_base_id = lcm_lora.base_model_id
            self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
            self.previous_use_lcm_lora = use_lora
            self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
            self.previous_use_openvino = lcm_diffusion_setting.use_openvino
            self.previous_task_type = lcm_diffusion_setting.diffusion_task
            self.previous_lora = lcm_diffusion_setting.lora.model_copy(deep=True)
            lcm_diffusion_setting.rebuild_pipeline = False
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                print(f"Pipeline : {self.pipeline}")
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                if self.use_openvino and is_openvino_device():
                    print(f"Pipeline : {self.pipeline}")
                else:
                    print(f"Pipeline : {self.img_to_img_pipeline}")
            if self.use_openvino:
                if lcm_diffusion_setting.lora.enabled:
                    print("Warning: Lora models not supported on OpenVINO mode")
            else:
                adapters = self.pipeline.get_active_adapters()
                print(f"Active adapters : {adapters}")

    def _get_timesteps(self):
        time_steps = self.pipeline.scheduler.config.get("timesteps")
        time_steps_value = [int(time_steps)] if time_steps else None
        return time_steps_value

    def generate(
        self,
        lcm_diffusion_setting: LCMDiffusionSetting,
        reshape: bool = False,
    ) -> Any:
        guidance_scale = lcm_diffusion_setting.guidance_scale
        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
        check_step_value = int(
            lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
        )
        if (
            lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
            and check_step_value < 1
        ):
            img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
            print(
                f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
            )

        if lcm_diffusion_setting.use_seed:
            cur_seed = lcm_diffusion_setting.seed
            if self.use_openvino:
                np.random.seed(cur_seed)
            else:
                torch.manual_seed(cur_seed)

        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
        if is_openvino_pipe:
            print("Using OpenVINO")
            if reshape and not self.is_openvino_init:
                print("Reshape and compile")
                self.pipeline.reshape(
                    batch_size=-1,
                    height=lcm_diffusion_setting.image_height,
                    width=lcm_diffusion_setting.image_width,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                )
                self.pipeline.compile()

            if self.is_openvino_init:
                self.is_openvino_init = False

        if not lcm_diffusion_setting.use_safety_checker:
            self.pipeline.safety_checker = None
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
                and not is_openvino_pipe
            ):
                self.img_to_img_pipeline.safety_checker = None

        if (
            not lcm_diffusion_setting.use_lcm_lora
            and not lcm_diffusion_setting.use_openvino
            and lcm_diffusion_setting.guidance_scale != 1.0
        ):
            print("Not using LCM-LoRA so setting guidance_scale 1.0")
            guidance_scale = 1.0

        controlnet_args = update_controlnet_arguments(lcm_diffusion_setting)
        if lcm_diffusion_setting.use_openvino:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images
            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps * 3,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                ).images

        else:
            if (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.text_to_image.value
            ):
                result_images = self.pipeline(
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=lcm_diffusion_setting.inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    timesteps=self._get_timesteps(),
                    **controlnet_args,
                ).images

            elif (
                lcm_diffusion_setting.diffusion_task
                == DiffusionTask.image_to_image.value
            ):
                result_images = self.img_to_img_pipeline(
                    image=lcm_diffusion_setting.init_image,
                    strength=lcm_diffusion_setting.strength,
                    prompt=lcm_diffusion_setting.prompt,
                    negative_prompt=lcm_diffusion_setting.negative_prompt,
                    num_inference_steps=img_to_img_inference_steps,
                    guidance_scale=guidance_scale,
                    width=lcm_diffusion_setting.image_width,
                    height=lcm_diffusion_setting.image_height,
                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
                    **controlnet_args,
                ).images
        return result_images
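A minimal driver for the generate() method above; a sketch only, since the owning class and its init() signature are not shown in this excerpt (the class name LCMTextToImage and the init() call are assumptions):

from backend.lcm_text_to_image import LCMTextToImage  # assumed class name
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting, DiffusionTask

settings = LCMDiffusionSetting()
settings.prompt = "a cat wearing sunglasses"
settings.diffusion_task = DiffusionTask.text_to_image.value
settings.inference_steps = 4

lcm = LCMTextToImage()
lcm.init(device="cpu", lcm_diffusion_setting=settings)  # assumed signature
images = lcm.generate(settings)
images[0].save("result.png")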
backend/lora.py
ADDED
@@ -0,0 +1,136 @@
import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path


# A basic class to keep track of the currently loaded LoRAs and
# their weights; the diffusers function \c get_active_adapters()
# returns a list of adapter names but not their weights so we need
# a way to keep track of the current LoRA weights to set whenever
# a new LoRA is loaded
class _lora_info:
    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __del__(self):
        self.path = None
        self.adapter_name = None


_loaded_loras = []
_current_pipeline = None


# This function loads a LoRA from the LoRA path setting, so it's
# possible to load multiple LoRAs by calling this function more than
# once with a different LoRA path setting; note that if you plan to
# load multiple LoRAs and dynamically change their weights, you
# might want to set the LoRA fuse option to False
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    if not lcm_diffusion_setting.lora.path:
        raise Exception("Empty lora model path")

    if not path.exists(lcm_diffusion_setting.lora.path):
        raise Exception("Lora model path is invalid")

    # If the pipeline has been rebuilt since the last call, remove all
    # references to previously loaded LoRAs and store the new pipeline
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        for lora in _loaded_loras:
            del lora
        del _loaded_loras
        _loaded_loras = []
        _current_pipeline = pipeline

    current_lora = _lora_info(
        lcm_diffusion_setting.lora.path,
        lcm_diffusion_setting.lora.weight,
    )
    _loaded_loras.append(current_lora)

    if lcm_diffusion_setting.lora.enabled:
        print(f"LoRA adapter name : {current_lora.adapter_name}")
        pipeline.load_lora_weights(
            FastStableDiffusionPaths.get_lora_models_path(),
            weight_name=Path(lcm_diffusion_setting.lora.path).name,
            local_files_only=True,
            adapter_name=current_lora.adapter_name,
        )
        update_lora_weights(
            pipeline,
            lcm_diffusion_setting,
        )

        if lcm_diffusion_setting.lora.fuse:
            pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    lora_models = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
    lora_models_map = {}
    for file_path in lora_models:
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map


# This function returns a list of (adapter_name, weight) tuples for the
# currently loaded LoRAs
def get_active_lora_weights():
    active_loras = []
    for lora_info in _loaded_loras:
        active_loras.append(
            (
                lora_info.adapter_name,
                lora_info.weight,
            )
        )
    return active_loras


# This function receives a pipeline, an lcm_diffusion_setting object and
# an optional list of updated (adapter_name, weight) tuples
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    global _loaded_loras
    global _current_pipeline
    if pipeline != _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return
    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    adapter_weights = zip(adapter_names, adapter_weights)
    print(f"Adapters: {list(adapter_weights)}")
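A sketch of the intended call sequence for the helpers above, assuming a diffusers pipeline and an LCMDiffusionSetting instance already exist in scope; the .safetensors path is a placeholder:

from backend.lora import (
    load_lora_weight,
    get_active_lora_weights,
    update_lora_weights,
)

# Keep fuse off so adapter weights can still be changed after loading
lcm_diffusion_setting.lora.enabled = True
lcm_diffusion_setting.lora.fuse = False
lcm_diffusion_setting.lora.path = "lora_models/pixel-art.safetensors"  # placeholder
lcm_diffusion_setting.lora.weight = 0.7
load_lora_weight(pipeline, lcm_diffusion_setting)

active = get_active_lora_weights()  # e.g. [("pixel-art", 0.7)]
updated = [(name, 0.4) for name, _ in active]
update_lora_weights(pipeline, lcm_diffusion_setting, updated)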
backend/models/device.py
ADDED
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class DeviceInfo(BaseModel):
    device_type: str
    device_name: str
    os: str
    platform: str
    processor: str
backend/models/gen_images.py
ADDED
@@ -0,0 +1,16 @@
from pydantic import BaseModel
from enum import Enum, auto
from paths import FastStableDiffusionPaths


class ImageFormat(str, Enum):
    """Image format"""

    JPEG = "jpeg"
    PNG = "png"


class GeneratedImages(BaseModel):
    path: str = FastStableDiffusionPaths.get_results_path()
    format: str = ImageFormat.PNG.value.upper()
    save_image: bool = True
backend/models/lcmdiffusion_setting.py
ADDED
@@ -0,0 +1,64 @@
from enum import Enum
from PIL import Image
from typing import Any, Optional, Union

from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
from paths import FastStableDiffusionPaths
from pydantic import BaseModel


class LCMLora(BaseModel):
    base_model_id: str = "Lykon/dreamshaper-8"
    lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"


class DiffusionTask(str, Enum):
    """Diffusion task types"""

    text_to_image = "text_to_image"
    image_to_image = "image_to_image"


class Lora(BaseModel):
    models_dir: str = FastStableDiffusionPaths.get_lora_models_path()
    path: Optional[Any] = None
    weight: Optional[float] = 0.5
    fuse: bool = True
    enabled: bool = False


class ControlNetSetting(BaseModel):
    adapter_path: Optional[str] = None  # ControlNet adapter path
    conditioning_scale: float = 0.5
    enabled: bool = False
    _control_image: Image = None  # Control image, PIL image


class LCMDiffusionSetting(BaseModel):
    lcm_model_id: str = LCM_DEFAULT_MODEL
    openvino_lcm_model_id: str = LCM_DEFAULT_MODEL_OPENVINO
    use_offline_model: bool = False
    use_lcm_lora: bool = False
    lcm_lora: Optional[LCMLora] = LCMLora()
    use_tiny_auto_encoder: bool = False
    use_openvino: bool = False
    prompt: str = ""
    negative_prompt: str = ""
    init_image: Any = None
    strength: Optional[float] = 0.6
    image_height: Optional[int] = 512
    image_width: Optional[int] = 512
    inference_steps: Optional[int] = 1
    guidance_scale: Optional[float] = 1
    number_of_images: Optional[int] = 1
    seed: Optional[int] = 123123
    use_seed: bool = False
    use_safety_checker: bool = False
    diffusion_task: str = DiffusionTask.text_to_image.value
    lora: Optional[Lora] = Lora()
    controlnet: Optional[Union[ControlNetSetting, list[ControlNetSetting]]] = None
    dirs: dict = {
        "controlnet": FastStableDiffusionPaths.get_controlnet_models_path(),
        "lora": FastStableDiffusionPaths.get_lora_models_path(),
    }
    rebuild_pipeline: bool = False
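A small sketch showing how an image-to-image run is described with the model above; the input file name is a placeholder:

from PIL import Image
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting, DiffusionTask

setting = LCMDiffusionSetting(
    prompt="watercolor landscape",
    diffusion_task=DiffusionTask.image_to_image.value,
    init_image=Image.open("input.png"),  # placeholder path
    strength=0.6,
    inference_steps=4,
    use_seed=True,
    seed=42,
)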
backend/models/upscale.py
ADDED
@@ -0,0 +1,9 @@
from enum import Enum


class UpscaleMode(str, Enum):
    """Upscale modes"""

    normal = "normal"
    sd_upscale = "sd_upscale"
    aura_sr = "aura_sr"
backend/openvino/custom_ov_model_vae_decoder.py
ADDED
@@ -0,0 +1,21 @@
from backend.device import is_openvino_device

if is_openvino_device():
    from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder


class CustomOVModelVaeDecoder(OVModelVaeDecoder):
    def __init__(
        self,
        model,
        parent_model,
        ov_config=None,
        model_dir=None,
    ):
        super(OVModelVaeDecoder, self).__init__(
            model,
            parent_model,
            ov_config,
            "vae_decoder",
            model_dir,
        )
backend/openvino/pipelines.py
ADDED
@@ -0,0 +1,75 @@
from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO
from backend.tiny_decoder import get_tiny_decoder_vae_model
from typing import Any
from backend.device import is_openvino_device
from paths import get_base_folder_name

if is_openvino_device():
    from huggingface_hub import snapshot_download
    from optimum.intel.openvino.modeling_diffusion import OVBaseModel

    from optimum.intel.openvino.modeling_diffusion import (
        OVStableDiffusionPipeline,
        OVStableDiffusionImg2ImgPipeline,
        OVStableDiffusionXLPipeline,
        OVStableDiffusionXLImg2ImgPipeline,
    )
    from backend.openvino.custom_ov_model_vae_decoder import CustomOVModelVaeDecoder


def ov_load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
):
    taesd_dir = snapshot_download(
        repo_id=get_tiny_decoder_vae_model(pipeline.__class__.__name__),
        local_files_only=use_local_model,
    )
    pipeline.vae_decoder = CustomOVModelVaeDecoder(
        model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
        parent_model=pipeline,
        model_dir=taesd_dir,
    )


def get_ov_text_to_image_pipeline(
    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
    use_local_model: bool = False,
) -> Any:
    if "xl" in get_base_folder_name(model_id).lower():
        pipeline = OVStableDiffusionXLPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    else:
        pipeline = OVStableDiffusionPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )

    return pipeline


def get_ov_image_to_image_pipeline(
    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
    use_local_model: bool = False,
) -> Any:
    if "xl" in get_base_folder_name(model_id).lower():
        pipeline = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    else:
        pipeline = OVStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    return pipeline
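A sketch of how these factories are meant to be combined on an OpenVINO-capable device; the model is downloaded on first use:

from backend.openvino.pipelines import (
    get_ov_text_to_image_pipeline,
    ov_load_taesd,
)

pipeline = get_ov_text_to_image_pipeline()  # defaults to LCM_DEFAULT_MODEL_OPENVINO
ov_load_taesd(pipeline)  # swap in the tiny TAESD VAE decoder
images = pipeline(
    prompt="a lighthouse at dusk",
    num_inference_steps=4,
    guidance_scale=1.0,
).images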
backend/pipelines/lcm.py
ADDED
@@ -0,0 +1,100 @@
from constants import LCM_DEFAULT_MODEL
from diffusers import (
    DiffusionPipeline,
    AutoencoderTiny,
    UNet2DConditionModel,
    LCMScheduler,
)
import torch
from backend.tiny_decoder import get_tiny_decoder_vae_model
from typing import Any
from diffusers import (
    LCMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    StableDiffusionControlNetPipeline,
)


def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
):
    pipeline = None
    unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    return pipeline


def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
):
    vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
    pipeline.vae = AutoencoderTiny.from_pretrained(
        vae_model,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )


def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
    pipeline_args={},
):
    pipeline = None
    if model_id == "latent-consistency/lcm-sdxl":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "stabilityai/stable-diffusion-xl-base-1.0",
            use_local_model,
        )

    elif model_id == "latent-consistency/lcm-ssd-1b":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "segmind/SSD-1B",
            use_local_model,
        )
    else:
        # pipeline = DiffusionPipeline.from_pretrained(
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    return pipeline


def get_image_to_image_pipeline(pipeline: Any) -> Any:
    components = pipeline.components
    pipeline_class = pipeline.__class__.__name__
    if (
        pipeline_class == "LatentConsistencyModelPipeline"
        or pipeline_class == "StableDiffusionPipeline"
    ):
        return StableDiffusionImg2ImgPipeline(**components)
    elif pipeline_class == "StableDiffusionControlNetPipeline":
        return AutoPipelineForImage2Image.from_pipe(pipeline)
    elif pipeline_class == "StableDiffusionXLPipeline":
        return StableDiffusionXLImg2ImgPipeline(**components)
    else:
        raise Exception(f"Unknown pipeline {pipeline_class}")
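A sketch combining the factories above: build a text-to-image LCM pipeline, attach the tiny VAE, then derive an image-to-image pipeline from the same components:

import torch
from backend.pipelines.lcm import (
    get_lcm_model_pipeline,
    get_image_to_image_pipeline,
    load_taesd,
)

pipeline = get_lcm_model_pipeline()  # defaults to LCM_DEFAULT_MODEL
load_taesd(pipeline, use_local_model=False, torch_data_type=torch.float32)
img2img = get_image_to_image_pipeline(pipeline)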
backend/pipelines/lcm_lora.py
ADDED
@@ -0,0 +1,82 @@
import pathlib
from os import path

import torch
from diffusers import (
    AutoPipelineForText2Image,
    LCMScheduler,
    StableDiffusionPipeline,
)


def load_lcm_weights(
    pipeline,
    use_local_model,
    lcm_lora_id,
):
    kwargs = {
        "local_files_only": use_local_model,
        "weight_name": "pytorch_lora_weights.safetensors",
    }
    pipeline.load_lora_weights(
        lcm_lora_id,
        **kwargs,
        adapter_name="lcm",
    )


def get_lcm_lora_pipeline(
    base_model_id: str,
    lcm_lora_id: str,
    use_local_model: bool,
    torch_data_type: torch.dtype,
    pipeline_args={},
):
    if pathlib.Path(base_model_id).suffix == ".safetensors":
        # SD 1.5 models only
        # When loading a .safetensors model, the pipeline has to be created
        # with StableDiffusionPipeline() since it's the only class that
        # defines the method from_single_file(); afterwards a new pipeline
        # is created using AutoPipelineForText2Image() for ControlNet
        # support, in case ControlNet is enabled
        if not path.exists(base_model_id):
            raise FileNotFoundError(
                f"Model file not found; please check your model path: {base_model_id}"
            )
        print("Using single file Safetensors model (Supported models - SD 1.5 models)")

        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            base_model_id,
            torch_dtype=torch_data_type,
            safety_checker=None,
            load_safety_checker=False,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            base_model_id,
            torch_dtype=torch_data_type,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    load_lcm_weights(
        pipeline,
        use_local_model,
        lcm_lora_id,
    )
    # Always fuse LCM-LoRA
    pipeline.fuse_lora()

    if "lcm" in lcm_lora_id.lower() or "hypersd" in lcm_lora_id.lower():
        print("LCM LoRA model detected so using recommended LCMScheduler")
        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    # pipeline.unet.to(memory_format=torch.channels_last)
    return pipeline
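A sketch of the LCM-LoRA path using the default IDs from LCMLora in backend/models/lcmdiffusion_setting.py; a local .safetensors path would take the from_single_file() branch above instead:

import torch
from backend.pipelines.lcm_lora import get_lcm_lora_pipeline

pipeline = get_lcm_lora_pipeline(
    base_model_id="Lykon/dreamshaper-8",
    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
    use_local_model=False,
    torch_data_type=torch.float32,
)
images = pipeline(
    prompt="an astronaut riding a horse",
    num_inference_steps=4,
    guidance_scale=1.0,
).images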
backend/tiny_decoder.py
ADDED
@@ -0,0 +1,32 @@
from constants import (
    TAESD_MODEL,
    TAESDXL_MODEL,
    TAESD_MODEL_OPENVINO,
    TAESDXL_MODEL_OPENVINO,
)


def get_tiny_decoder_vae_model(pipeline_class) -> str:
    print(f"Pipeline class : {pipeline_class}")
    if (
        pipeline_class == "LatentConsistencyModelPipeline"
        or pipeline_class == "StableDiffusionPipeline"
        or pipeline_class == "StableDiffusionImg2ImgPipeline"
        or pipeline_class == "StableDiffusionControlNetPipeline"
        or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline"
    ):
        return TAESD_MODEL
    elif (
        pipeline_class == "StableDiffusionXLPipeline"
        or pipeline_class == "StableDiffusionXLImg2ImgPipeline"
    ):
        return TAESDXL_MODEL
    elif (
        pipeline_class == "OVStableDiffusionPipeline"
        or pipeline_class == "OVStableDiffusionImg2ImgPipeline"
    ):
        return TAESD_MODEL_OPENVINO
    elif pipeline_class == "OVStableDiffusionXLPipeline":
        return TAESDXL_MODEL_OPENVINO
    else:
        raise Exception("No valid pipeline class found!")
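The lookup is keyed on a pipeline class name, for example:

from backend.tiny_decoder import get_tiny_decoder_vae_model

vae_repo = get_tiny_decoder_vae_model("StableDiffusionXLPipeline")  # -> TAESDXL_MODEL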
backend/upscale/aura_sr.py
ADDED
@@ -0,0 +1,834 @@
# AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is
# based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
#
# https://mingukkang.github.io/GigaGAN/
from math import log2, ceil
from functools import partial
from typing import Any, Optional, List, Iterable

import torch
from torchvision import transforms
from PIL import Image
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange


def get_same_padding(size, kernel, dilation, stride):
    return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2


class AdaptiveConv2DMod(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        demod=True,
        stride=1,
        dilation=1,
        eps=1e-8,
        num_conv_kernels=1,  # set this to be greater than 1 for adaptive
    ):
        super().__init__()
        self.eps = eps

        self.dim_out = dim_out

        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.adaptive = num_conv_kernels > 1

        self.weights = nn.Parameter(
            torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
        )

        self.demod = demod

        nn.init.kaiming_normal_(
            self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(
        self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
    ):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """

        b, h = fmap.shape[0], fmap.shape[-2]

        # account for feature map that has been expanded by the scale in the first dimension
        # due to multiscale inputs and outputs

        if mod.shape[0] != b:
            mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])

        if exists(kernel_mod):
            kernel_mod_has_el = kernel_mod.numel() > 0

            assert self.adaptive or not kernel_mod_has_el

            if kernel_mod_has_el and kernel_mod.shape[0] != b:
                kernel_mod = repeat(
                    kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
                )

        # prepare weights for modulation

        weights = self.weights

        if self.adaptive:
            weights = repeat(weights, "... -> b ...", b=b)

            # determine an adaptive weight and 'select' the kernel to use with softmax

            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim=-1)
            kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")

            weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")

        # do the modulation, demodulation, as done in stylegan2

        mod = rearrange(mod, "b i -> b 1 i 1 1")

        weights = weights * (mod + 1)

        if self.demod:
            inv_norm = (
                reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
                .clamp(min=self.eps)
                .rsqrt()
            )
            weights = weights * inv_norm

        fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")

        weights = rearrange(weights, "b o ... -> (b o) ...")

        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        fmap = F.conv2d(fmap, weights, padding=padding, groups=b)

        return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.scale = nn.Parameter(torch.randn(1))
        self.flash = flash

    def flash_attn(self, q, k, v):
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )
        return out

    def forward(self, q, k, v):
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def cast_tuple(t, length=1):
    if isinstance(t, tuple):
        return t
    return (t,) * length


def identity(t, *args, **kwargs):
    return t


def is_power_of_two(n):
    return log2(n).is_integer()


def null_iterator():
    while True:
        yield None


def Downsample(dim, dim_out=None):
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.eps = 1e-4

    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
        super().__init__()
        self.proj = AdaptiveConv2DMod(
            dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
        )
        self.kernel = 3
        self.dilation = 1
        self.stride = 1

        self.act = nn.SiLU()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        conv_mods_iter = default(conv_mods_iter, null_iterator())

        x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
    ):
        super().__init__()
        style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])

        self.block1 = Block(
            dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.block2 = Block(
            dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        h = self.block1(x, conv_mods_iter=conv_mods_iter)
        h = self.block2(h, conv_mods_iter=conv_mods_iter)

        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32, flash=False):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)

        self.attend = Attend(flash=flash)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
        )

        out = self.attend(q, k, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


# feedforward
def FeedForward(dim, mult=4):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Conv2d(dim * mult, dim, 1),
    )


# transformers
class Transformer(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class LinearTransformer(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class NearestNeighborhoodUpsample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        if x.shape[0] >= 64:
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class EqualLinear(nn.Module):
    def __init__(self, dim, dim_out, lr_mul=1, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out))

        self.lr_mul = lr_mul

    def forward(self, input):
        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)


class StyleGanNetwork(nn.Module):
    def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_text_latent = dim_text_latent

        layers = []
        for i in range(depth):
            is_first = i == 0

            if is_first:
                dim_in_layer = dim_in + dim_text_latent
            else:
                dim_in_layer = dim_out

            dim_out_layer = dim_out

            layers.extend(
                [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
            )

        self.net = nn.Sequential(*layers)

    def forward(self, x, text_latent=None):
        x = F.normalize(x, dim=1)
        if self.dim_text_latent > 0:
            assert exists(text_latent)
            x = torch.cat((x, text_latent), dim=-1)
        return self.net(x)


class UnetUpsampler(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        *,
        image_size: int,
        input_image_size: int,
        init_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        style_network: Optional[dict] = None,
        up_dim_mults: tuple = (1, 2, 4, 8, 16),
        down_dim_mults: tuple = (4, 8, 16),
        channels: int = 3,
        resnet_block_groups: int = 8,
        full_attn: tuple = (False, False, False, True, True),
        flash_attn: bool = True,
        self_attn_dim_head: int = 64,
        self_attn_heads: int = 8,
        attn_depths: tuple = (2, 2, 2, 2, 4),
        mid_attn_depth: int = 4,
        num_conv_kernels: int = 4,
        resize_mode: str = "bilinear",
        unconditional: bool = True,
        skip_connect_scale: Optional[float] = None,
    ):
        super().__init__()
        self.style_network = style_network = StyleGanNetwork(**style_network)
        self.unconditional = unconditional
        assert not (
            unconditional
            and exists(style_network)
            and style_network.dim_text_latent > 0
        )

        assert is_power_of_two(image_size) and is_power_of_two(
            input_image_size
        ), "both output image size and input image size must be power of 2"
        assert (
            input_image_size < image_size
        ), "input image size must be smaller than the output image size, thus upsampling"

        self.image_size = image_size
        self.input_image_size = input_image_size

        style_embed_split_dims = []

        self.channels = channels
        input_channels = channels

        init_dim = default(init_dim, dim)

        up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
        init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
        down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
        self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)

        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))

        block_klass = partial(
            ResnetBlock,
            groups=resnet_block_groups,
            num_conv_kernels=num_conv_kernels,
            style_dims=style_embed_split_dims,
        )

        FullAttention = partial(Transformer, flash_attn=flash_attn)
        *_, mid_dim = up_dims

        self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        block_count = 6

        for ind, (
            (dim_in, dim_out),
            layer_full_attn,
            layer_attn_depth,
        ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            blocks = []
            for i in range(block_count):
                blocks.append(block_klass(dim_in, dim_in))

            self.downs.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                                nn.Conv2d(
                                    dim_in, dim_out, kernel_size=3, stride=2, padding=1
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.mid_block1 = block_klass(mid_dim, mid_dim)
        self.mid_attn = FullAttention(
            mid_dim,
            dim_head=self_attn_dim_head,
            heads=self_attn_heads,
            depth=mid_attn_depth,
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim)

        *_, last_dim = up_dims

        for ind, (
            (dim_in, dim_out),
            layer_full_attn,
            layer_attn_depth,
        ) in enumerate(
            zip(
                reversed(up_in_out),
                reversed(full_attn),
                reversed(attn_depths),
            )
        ):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            blocks = []
            input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
            for i in range(block_count):
                blocks.append(block_klass(input_dim, dim_in))

            self.ups.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                NearestNeighborhoodUpsample(
                                    last_dim if ind == 0 else dim_out,
                                    dim_in,
                                ),
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim, dim)
        self.final_to_rgb = nn.Conv2d(dim, channels, 1)
        self.resize_mode = resize_mode
        self.style_to_conv_modulations = nn.Linear(
            style_network.dim_out, sum(style_embed_split_dims)
        )
        self.style_embed_split_dims = style_embed_split_dims

    @property
    def allowable_rgb_resolutions(self):
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2**p, allowed_rgb_res_base)]

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        return F.interpolate(x, (size, size), mode=self.resize_mode)

    def forward(
        self,
        lowres_image: torch.Tensor,
        styles: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
        global_text_tokens: Optional[torch.Tensor] = None,
        return_all_rgbs: bool = False,
    ):
        x = lowres_image

        noise_scale = 0.001  # Adjust the scale of the noise as needed
        noise_aug = torch.randn_like(x) * noise_scale
        x = x + noise_aug
        x = x.clamp(0, 1)

        shape = x.shape
        batch_size = shape[0]

        assert shape[-2:] == ((self.input_image_size,) * 2)

        # styles
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(
                noise,
                torch.randn(
                    (batch_size, self.style_network.dim_in), device=self.device
                ),
            )
            styles = self.style_network(noise, global_text_tokens)

        # project styles to conv modulations
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
        conv_mods = iter(conv_mods)

        x = self.init_conv(x)

        h = []
        for blocks, (attn, downsample) in self.downs:
            for block in blocks:
                x = block(x, conv_mods_iter=conv_mods)
                h.append(x)

            if attn is not None:
                x = attn(x)

            x = downsample(x)

        x = self.mid_block1(x, conv_mods_iter=conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter=conv_mods)

        for (
            blocks,
            (
                upsample,
                attn,
            ),
        ) in self.ups:
            x = upsample(x)
            for block in blocks:
                if h != []:
                    res = h.pop()
                    res = res * self.skip_connect_scale
                    x = torch.cat((x, res), dim=1)

                x = block(x, conv_mods_iter=conv_mods)

            if attn is not None:
                x = attn(x)

        x = self.final_res_block(x, conv_mods_iter=conv_mods)
        rgb = self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        return rgb, []


def tile_image(image, chunk_size=64):
    c, h, w = image.shape
    h_chunks = ceil(h / chunk_size)
    w_chunks = ceil(w / chunk_size)
    tiles = []
    for i in range(h_chunks):
        for j in range(w_chunks):
            tile = image[:, i * chunk_size:(i + 1) * chunk_size, j * chunk_size:(j + 1) * chunk_size]
            tiles.append(tile)
    return tiles, h_chunks, w_chunks


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
    # Determine the shape of the output tensor
    c = tiles[0].shape[0]
    h = h_chunks * chunk_size
    w = w_chunks * chunk_size

    # Create an empty tensor to hold the merged image
    merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)

    # Iterate over the tiles and place them in the correct position
    for idx, tile in enumerate(tiles):
        i = idx // w_chunks
        j = idx % w_chunks

        h_start = i * chunk_size
        w_start = j * chunk_size

        tile_h, tile_w = tile.shape[1:]
        merged[:, h_start:h_start + tile_h, w_start:w_start + tile_w] = tile

    return merged


class AuraSR:
    def __init__(self, config: dict[str, Any], device: str = "cuda"):
        self.upsampler = UnetUpsampler(**config).to(device)
        self.input_image_size = config["input_image_size"]

    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", device: str = "cuda", use_safetensors: bool = True):
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

            # For local files, we need to provide the config separately
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
            hf_model_path = Path(snapshot_download(model_id))
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config, device)

        if use_safetensors:
            try:
                from safetensors.torch import load_file
                checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id)
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
            checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id)

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model

    @torch.no_grad()
    def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
        pad_h = (self.input_image_size - h % self.input_image_size) % self.input_image_size
        pad_w = (self.input_image_size - w % self.input_image_size) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(image_tensor, (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)

        # Batch processing of tiles
        num_tiles = len(tiles)
        batches = [tiles[i:i + max_batch_size] for i in range(0, num_tiles, max_batch_size)]
        reconstructed_tiles = []

        for batch in batches:
            model_input = torch.stack(batch).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
                noise=torch.randn(model_input.shape[0], 128, device=device)
            )
            reconstructed_tiles.extend(list(generator_output.clamp_(0, 1).detach().cpu()))

        merged_tensor = merge_tiles(reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4)
        unpadded = merged_tensor[:, :h * 4, :w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
backend/upscale/aura_sr_upscale.py
ADDED
@@ -0,0 +1,9 @@
from backend.upscale.aura_sr import AuraSR
from PIL import Image


def upscale_aura_sr(image_path: str):

    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device="cpu")
    image_in = Image.open(image_path)  # .resize((256, 256))
    return aura_sr.upscale_4x(image_in)
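A sketch of calling the wrapper above; the image path is a placeholder:

from backend.upscale.aura_sr_upscale import upscale_aura_sr

upscaled = upscale_aura_sr("results/sample.png")  # placeholder path
upscaled.save("results/sample_4x.png")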
backend/upscale/edsr_upscale_onnx.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
import numpy as np
|
2 |
+
import onnxruntime
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def upscale_edsr_2x(image_path: str):
|
8 |
+
input_image = Image.open(image_path).convert("RGB")
|
9 |
+
input_image = np.array(input_image).astype("float32")
|
10 |
+
input_image = np.transpose(input_image, (2, 0, 1))
|
11 |
+
img_arr = np.expand_dims(input_image, axis=0)
|
12 |
+
|
13 |
+
if np.max(img_arr) > 256: # 16-bit image
|
14 |
+
max_range = 65535
|
15 |
+
else:
|
16 |
+
max_range = 255.0
|
17 |
+
img = img_arr / max_range
|
18 |
+
|
19 |
+
model_path = hf_hub_download(
|
20 |
+
repo_id="rupeshs/edsr-onnx",
|
21 |
+
filename="edsr_onnxsim_2x.onnx",
|
22 |
+
)
|
23 |
+
sess = onnxruntime.InferenceSession(model_path)
|
24 |
+
|
25 |
+
input_name = sess.get_inputs()[0].name
|
26 |
+
output_name = sess.get_outputs()[0].name
|
27 |
+
output = sess.run(
|
28 |
+
[output_name],
|
29 |
+
{input_name: img},
|
30 |
+
)[0]
|
31 |
+
|
32 |
+
result = output.squeeze()
|
33 |
+
result = result.clip(0, 1)
|
34 |
+
image_array = np.transpose(result, (1, 2, 0))
|
35 |
+
image_array = np.uint8(image_array * 255)
|
36 |
+
upscaled_image = Image.fromarray(image_array)
|
37 |
+
return upscaled_image
|
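Since the input is first converted with convert("RGB"), the array is 8-bit in practice and max_range stays 255.0; the 65535 branch is defensive. A quick usage sketch (file names illustrative):

from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x

upscaled = upscale_edsr_2x("photo.png")  # returns a PIL image at 2x size
upscaled.save("photo_2x.png")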
backend/upscale/tiled_upscale.py
ADDED
@@ -0,0 +1,238 @@
import time
import math
import logging
from PIL import Image, ImageDraw, ImageFilter
from backend.models.lcmdiffusion_setting import DiffusionTask
from context import Context
from constants import DEVICE


def generate_upscaled_image(
    config,
    input_path=None,
    strength=0.3,
    scale_factor=2.0,
    tile_overlap=16,
    upscale_settings=None,
    context: Context = None,
    output_path=None,
    image_format="PNG",
):
    if config is None or (
        (input_path is None or input_path == "") and upscale_settings is None
    ):
        logging.error("Wrong arguments in tiled upscale function call!")
        return

    # Use the upscale_settings dict if provided; otherwise, build the
    # upscale_settings dict using the function arguments and default values
    if upscale_settings is None:
        upscale_settings = {
            "source_file": input_path,
            "target_file": None,
            "output_format": image_format,
            "strength": strength,
            "scale_factor": scale_factor,
            "prompt": config.lcm_diffusion_setting.prompt,
            "tile_overlap": tile_overlap,
            "tile_size": 256,
            "tiles": [],
        }
        source_image = Image.open(input_path)  # PIL image
    else:
        source_image = Image.open(upscale_settings["source_file"])

    upscale_settings["source_image"] = source_image

    if upscale_settings["target_file"]:
        result = Image.open(upscale_settings["target_file"])
    else:
        result = Image.new(
            mode="RGBA",
            size=(
                source_image.size[0] * int(upscale_settings["scale_factor"]),
                source_image.size[1] * int(upscale_settings["scale_factor"]),
            ),
            color=(0, 0, 0, 0),
        )
    upscale_settings["target_image"] = result

    # If the custom tile definition array 'tiles' is empty, proceed with the
    # default tiled upscale task by defining all the possible image tiles; note
    # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
    # image width and height are no longer constrained to multiples of 256 but
    # are instead multiples of the actual tile size
    if len(upscale_settings["tiles"]) == 0:
        tile_size = upscale_settings["tile_size"]
        scale_factor = upscale_settings["scale_factor"]
        tile_overlap = upscale_settings["tile_overlap"]
        total_cols = math.ceil(
            source_image.size[0] / tile_size
        )  # Image width / tile size
        total_rows = math.ceil(
            source_image.size[1] / tile_size
        )  # Image height / tile size
        for y in range(0, total_rows):
            y_offset = tile_overlap if y > 0 else 0  # Tile mask offset
            for x in range(0, total_cols):
                x_offset = tile_overlap if x > 0 else 0  # Tile mask offset
                x1 = x * tile_size
                y1 = y * tile_size
                w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
                h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
                mask_box = (  # Default tile mask box definition
                    x_offset,
                    y_offset,
                    int(w * scale_factor),
                    int(h * scale_factor),
                )
                upscale_settings["tiles"].append(
                    {
                        "x": x1,
                        "y": y1,
                        "w": w,
                        "h": h,
                        "mask_box": mask_box,
                        "prompt": upscale_settings["prompt"],  # Use top level prompt if available
                        "scale_factor": scale_factor,
                    }
                )

    # Generate the output image tiles
    for i in range(0, len(upscale_settings["tiles"])):
        generate_upscaled_tile(
            config,
            i,
            upscale_settings,
            context=context,
        )

    # Save completed upscaled image
    if upscale_settings["output_format"].upper() == "JPEG":
        result_rgb = result.convert("RGB")
        result.close()
        result = result_rgb
    result.save(output_path)
    result.close()
    source_image.close()
    return


def get_current_tile(
    config,
    context,
    strength,
):
    config.lcm_diffusion_setting.strength = strength
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    if (
        config.lcm_diffusion_setting.use_tiny_auto_encoder
        and config.lcm_diffusion_setting.use_openvino
    ):
        config.lcm_diffusion_setting.use_tiny_auto_encoder = False
    current_tile = context.generate_text_to_image(
        settings=config,
        reshape=True,
        device=DEVICE,
        save_images=False,
        save_config=False,
    )[0]
    return current_tile


# Generates a single tile from the source image as defined in the
# upscale_settings["tiles"] array with the corresponding index and pastes the
# generated tile into the target image using the corresponding mask and scale
# factor; note that the scale factors for the target image and the individual
# tiles can differ; this function will adjust scale factors as needed
def generate_upscaled_tile(
    config,
    index,
    upscale_settings,
    context: Context = None,
):
    if config is None or upscale_settings is None:
        logging.error("Wrong arguments in tile creation function call!")
        return

    x = upscale_settings["tiles"][index]["x"]
    y = upscale_settings["tiles"][index]["y"]
    w = upscale_settings["tiles"][index]["w"]
    h = upscale_settings["tiles"][index]["h"]
    tile_prompt = upscale_settings["tiles"][index]["prompt"]
    scale_factor = upscale_settings["scale_factor"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    target_width = int(w * tile_scale_factor)
    target_height = int(h * tile_scale_factor)
    strength = upscale_settings["strength"]
    source_image = upscale_settings["source_image"]
    target_image = upscale_settings["target_image"]
    mask_image = generate_tile_mask(config, index, upscale_settings)

    config.lcm_diffusion_setting.number_of_images = 1
    config.lcm_diffusion_setting.prompt = tile_prompt
    config.lcm_diffusion_setting.image_width = target_width
    config.lcm_diffusion_setting.image_height = target_height
    config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))

    current_tile = None
    print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])}")
    if tile_prompt is None or tile_prompt == "":
        config.lcm_diffusion_setting.prompt = ""
        config.lcm_diffusion_setting.negative_prompt = ""
        current_tile = get_current_tile(config, context, strength)
    else:
        # Attempt to use img2img with low denoising strength to
        # generate the tiles with the extra aid of a prompt
        # context = get_context(InterfaceType.CLI)
        current_tile = get_current_tile(config, context, strength)

    if math.isclose(scale_factor, tile_scale_factor):
        target_image.paste(
            current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
        )
    else:
        target_image.paste(
            current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
            (int(x * scale_factor), int(y * scale_factor)),
            mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
        )
    mask_image.close()
    current_tile.close()
    config.lcm_diffusion_setting.init_image.close()


# Generate tile mask using the box definition in the upscale_settings["tiles"]
# array with the corresponding index; note that tile masks for the default
# tiled upscale task could be reused, but that would complicate the code, so
# new tile masks are instead created for each tile
def generate_tile_mask(
    config,
    index,
    upscale_settings,
):
    scale_factor = upscale_settings["scale_factor"]
    tile_overlap = upscale_settings["tile_overlap"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
    h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
    # The Stable Diffusion pipeline automatically adjusts the output size
    # to multiples of 8 pixels; the mask must be created with the same
    # size as the output tile
    w = w - (w % 8)
    h = h - (h % 8)
    mask_box = upscale_settings["tiles"][index]["mask_box"]
    if mask_box is None:
        # Build a default solid mask with soft/transparent edges
        mask_box = (
            tile_overlap,
            tile_overlap,
            w - tile_overlap,
            h - tile_overlap,
        )
    mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
    mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
    mask_image.close()
    return mask_blur
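For reference, a minimal custom upscale_settings dict accepted by generate_upscaled_image, patching a single 256x256 region of the source image (all values illustrative; the keys match those consumed above):

upscale_settings = {
    "source_file": "input.png",   # image the tiles are cropped from
    "target_file": None,          # None -> start from a blank RGBA canvas
    "output_format": "PNG",
    "strength": 0.3,
    "scale_factor": 2.0,
    "prompt": "a detailed photograph",
    "tile_overlap": 16,
    "tile_size": 256,
    "tiles": [
        {
            "x": 0,
            "y": 0,
            "w": 256,
            "h": 256,
            "mask_box": None,     # None -> default soft-edged mask
            "prompt": "a detailed photograph",
            "scale_factor": 2.0,
        }
    ],
}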
backend/upscale/upscaler.py
ADDED
@@ -0,0 +1,52 @@
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.models.upscale import UpscaleMode
from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
from backend.upscale.aura_sr_upscale import upscale_aura_sr
from backend.upscale.tiled_upscale import generate_upscaled_image
from context import Context
from PIL import Image
from state import get_settings


config = get_settings()


def upscale_image(
    context: Context,
    src_image_path: str,
    dst_image_path: str,
    scale_factor: int = 2,
    upscale_mode: UpscaleMode = UpscaleMode.normal.value,
):
    if upscale_mode == UpscaleMode.normal.value:

        upscaled_img = upscale_edsr_2x(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    elif upscale_mode == UpscaleMode.aura_sr.value:
        upscaled_img = upscale_aura_sr(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    else:
        config.settings.lcm_diffusion_setting.strength = (
            0.3 if config.settings.lcm_diffusion_setting.use_openvino else 0.1
        )
        config.settings.lcm_diffusion_setting.diffusion_task = (
            DiffusionTask.image_to_image.value
        )

        generate_upscaled_image(
            config.settings,
            src_image_path,
            config.settings.lcm_diffusion_setting.strength,
            upscale_settings=None,
            context=context,
            tile_overlap=(
                32 if config.settings.lcm_diffusion_setting.use_openvino else 16
            ),
            output_path=dst_image_path,
            image_format=config.settings.generated_images.format,
        )
        print(f"Upscaled image saved {dst_image_path}")

    return [Image.open(dst_image_path)]
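A sketch of a typical call into this module (the Context construction follows context.py below; paths are illustrative):

from backend.models.upscale import UpscaleMode
from backend.upscale.upscaler import upscale_image
from context import Context
from models.interface_types import InterfaceType

context = Context(InterfaceType.CLI)
images = upscale_image(
    context,
    "input.png",
    "input_2x.png",
    upscale_mode=UpscaleMode.normal.value,  # EDSR path; no diffusion pipeline needed
)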
constants.py
ADDED
@@ -0,0 +1,20 @@
from os import environ

APP_VERSION = "v1.0.0 beta 33"
LCM_DEFAULT_MODEL = "stabilityai/sd-turbo"
LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino"
APP_NAME = "FastSD CPU"
APP_SETTINGS_FILE = "settings.yaml"
RESULTS_DIRECTORY = "results"
CONFIG_DIRECTORY = "configs"
DEVICE = environ.get("DEVICE", "cpu")
SD_MODELS_FILE = "stable-diffusion-models.txt"
LCM_LORA_MODELS_FILE = "lcm-lora-models.txt"
OPENVINO_LCM_MODELS_FILE = "openvino-lcm-models.txt"
TAESD_MODEL = "madebyollin/taesd"
TAESDXL_MODEL = "madebyollin/taesdxl"
TAESD_MODEL_OPENVINO = "deinferno/taesd-openvino"
LCM_MODELS_FILE = "lcm-models.txt"
TAESDXL_MODEL_OPENVINO = "rupeshs/taesdxl-openvino"
LORA_DIRECTORY = "lora_models"
CONTROLNET_DIRECTORY = "controlnet_models"
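Because DEVICE is read from the environment with a "cpu" fallback, the compute device can be switched without touching the code; a sketch (the "gpu" value is illustrative, and the variable must be set before constants is first imported):

import os

os.environ["DEVICE"] = "gpu"  # illustrative; must happen before the import below

from constants import DEVICE

print(DEVICE)  # -> "gpu"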
context.py
ADDED
@@ -0,0 +1,77 @@
from typing import Any
from app_settings import Settings
from models.interface_types import InterfaceType
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.lcm_text_to_image import LCMTextToImage
from time import perf_counter
from backend.image_saver import ImageSaver
from pprint import pprint


class Context:
    def __init__(
        self,
        interface_type: InterfaceType,
        device="cpu",
    ):
        self.interface_type = interface_type.value
        self.lcm_text_to_image = LCMTextToImage(device)
        self._latency = 0

    @property
    def latency(self):
        return self._latency

    def generate_text_to_image(
        self,
        settings: Settings,
        reshape: bool = False,
        device: str = "cpu",
        save_images=True,
        save_config=True,
    ) -> Any:
        if (
            settings.lcm_diffusion_setting.use_tiny_auto_encoder
            and settings.lcm_diffusion_setting.use_openvino
        ):
            print(
                "WARNING: Tiny AutoEncoder is not supported in Image to image mode (OpenVINO)"
            )
        tick = perf_counter()
        from state import get_settings

        if (
            settings.lcm_diffusion_setting.diffusion_task
            == DiffusionTask.text_to_image.value
        ):
            settings.lcm_diffusion_setting.init_image = None

        if save_config:
            get_settings().save()

        pprint(settings.lcm_diffusion_setting.model_dump())
        if not settings.lcm_diffusion_setting.lcm_lora:
            return None
        self.lcm_text_to_image.init(
            device,
            settings.lcm_diffusion_setting,
        )
        images = self.lcm_text_to_image.generate(
            settings.lcm_diffusion_setting,
            reshape,
        )
        elapsed = perf_counter() - tick

        if save_images and settings.generated_images.save_image:
            ImageSaver.save_images(
                settings.generated_images.path,
                images=images,
                lcm_diffusion_setting=settings.lcm_diffusion_setting,
                format=settings.generated_images.format,
            )
        self._latency = elapsed
        print(f"Latency : {elapsed:.2f} seconds")
        if settings.lcm_diffusion_setting.controlnet:
            if settings.lcm_diffusion_setting.controlnet.enabled:
                images.append(settings.lcm_diffusion_setting.controlnet._control_image)
        return images
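A minimal text-to-image round trip through Context (a sketch; it assumes application settings come from state.get_settings(), as in upscaler.py above):

from constants import DEVICE
from context import Context
from models.interface_types import InterfaceType
from state import get_settings

app_settings = get_settings()
app_settings.settings.lcm_diffusion_setting.prompt = "a cat wearing a hat"
context = Context(InterfaceType.CLI)
images = context.generate_text_to_image(
    settings=app_settings.settings,
    device=DEVICE,
)
print(f"Generated {len(images)} image(s) in {context.latency:.2f} seconds")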
frontend/cli_interactive.py
ADDED
@@ -0,0 +1,655 @@
from os import path
from PIL import Image
from typing import Any

from constants import DEVICE
from paths import FastStableDiffusionPaths
from backend.upscale.upscaler import upscale_image
from backend.controlnet import controlnet_settings_from_dict
from backend.upscale.tiled_upscale import generate_upscaled_image
from frontend.webui.image_variations_ui import generate_image_variations
from backend.lora import (
    get_active_lora_weights,
    update_lora_weights,
    load_lora_weight,
)
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    LCMDiffusionSetting,
    ControlNetSetting,
)


_batch_count = 1
_edit_lora_settings = False


def user_value(
    value_type: type,
    message: str,
    default_value: Any,
) -> Any:
    try:
        value = value_type(input(message))
    except:
        value = default_value
    return value


def interactive_mode(
    config,
    context,
):
    print("=============================================")
    print("Welcome to FastSD CPU Interactive CLI")
    print("=============================================")
    while True:
        print("> 1. Text to Image")
        print("> 2. Image to Image")
        print("> 3. Image Variations")
        print("> 4. EDSR Upscale")
        print("> 5. SD Upscale")
        print("> 6. Edit default generation settings")
        print("> 7. Edit LoRA settings")
        print("> 8. Edit ControlNet settings")
        print("> 9. Edit negative prompt")
        print("> 10. Quit")
        option = user_value(
            int,
            "Enter a Diffusion Task number (1): ",
            1,
        )
        if option not in range(1, 11):
            print("Wrong Diffusion Task number!")
            exit()

        if option == 1:
            interactive_txt2img(
                config,
                context,
            )
        elif option == 2:
            interactive_img2img(
                config,
                context,
            )
        elif option == 3:
            interactive_variations(
                config,
                context,
            )
        elif option == 4:
            interactive_edsr(
                config,
                context,
            )
        elif option == 5:
            interactive_sdupscale(
                config,
                context,
            )
        elif option == 6:
            interactive_settings(
                config,
                context,
            )
        elif option == 7:
            interactive_lora(
                config,
                context,
                True,
            )
        elif option == 8:
            interactive_controlnet(
                config,
                context,
                True,
            )
        elif option == 9:
            interactive_negative(
                config,
                context,
            )
        elif option == 10:
            exit()


def interactive_negative(
    config,
    context,
):
    settings = config.lcm_diffusion_setting
    print(f"Current negative prompt: '{settings.negative_prompt}'")
    user_input = input("Write a negative prompt (set guidance > 1.0): ")
    if user_input == "":
        return
    else:
        settings.negative_prompt = user_input


def interactive_controlnet(
    config,
    context,
    menu_flag=False,
):
    """
    @param menu_flag: Indicates whether this function was called from the main
    interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
    """
    settings = config.lcm_diffusion_setting
    if not settings.controlnet:
        settings.controlnet = ControlNetSetting()

    current_enabled = settings.controlnet.enabled
    current_adapter_path = settings.controlnet.adapter_path
    current_conditioning_scale = settings.controlnet.conditioning_scale
    current_control_image = settings.controlnet._control_image

    option = input("Enable ControlNet? (y/N): ")
    settings.controlnet.enabled = True if option.upper() == "Y" else False
    if settings.controlnet.enabled:
        option = input(
            f"Enter ControlNet adapter path ({settings.controlnet.adapter_path}): "
        )
        if option != "":
            settings.controlnet.adapter_path = option
        settings.controlnet.conditioning_scale = user_value(
            float,
            f"Enter ControlNet conditioning scale ({settings.controlnet.conditioning_scale}): ",
            settings.controlnet.conditioning_scale,
        )
        option = input(
            "Enter ControlNet control image path (Leave empty to reuse current): "
        )
        if option != "":
            try:
                new_image = Image.open(option)
                settings.controlnet._control_image = new_image
            except (AttributeError, FileNotFoundError) as e:
                settings.controlnet._control_image = None
        if (
            not settings.controlnet.adapter_path
            or not path.exists(settings.controlnet.adapter_path)
            or not settings.controlnet._control_image
        ):
            print("Invalid ControlNet settings! Disabling ControlNet")
            settings.controlnet.enabled = False

    if (
        settings.controlnet.enabled != current_enabled
        or settings.controlnet.adapter_path != current_adapter_path
    ):
        settings.rebuild_pipeline = True


def interactive_lora(
    config,
    context,
    menu_flag=False,
):
    """
    @param menu_flag: Indicates whether this function was called from the main
    interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
    """
    if context is None or context.lcm_text_to_image.pipeline is None:
        print("Diffusion pipeline not initialized, please run a generation task first!")
        return

    print("> 1. Change LoRA weights")
    print("> 2. Load new LoRA model")
    option = user_value(
        int,
        "Enter a LoRA option (1): ",
        1,
    )
    if option not in range(1, 3):
        print("Wrong LoRA option!")
        return

    if option == 1:
        update_weights = []
        active_weights = get_active_lora_weights()
        for lora in active_weights:
            weight = user_value(
                float,
                f"Enter a new LoRA weight for {lora[0]} ({lora[1]}): ",
                lora[1],
            )
            update_weights.append(
                (
                    lora[0],
                    weight,
                )
            )
        if len(update_weights) > 0:
            update_lora_weights(
                context.lcm_text_to_image.pipeline,
                config.lcm_diffusion_setting,
                update_weights,
            )
    elif option == 2:
        # Load a new LoRA
        settings = config.lcm_diffusion_setting
        settings.lora.fuse = False
        settings.lora.enabled = False
        settings.lora.path = input("Enter LoRA model path: ")
        settings.lora.weight = user_value(
            float,
            "Enter a LoRA weight (0.5): ",
            0.5,
        )
        if not path.exists(settings.lora.path):
            print("Invalid LoRA model path!")
            return
        settings.lora.enabled = True
        load_lora_weight(context.lcm_text_to_image.pipeline, settings)

    if menu_flag:
        global _edit_lora_settings
        _edit_lora_settings = False
        option = input("Edit LoRA settings after every generation? (y/N): ")
        if option.upper() == "Y":
            _edit_lora_settings = True


def interactive_settings(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    print("Enter generation settings (leave empty to use current value)")
    print("> 1. Use LCM")
    print("> 2. Use LCM-Lora")
    print("> 3. Use OpenVINO")
    option = user_value(
        int,
        "Select inference model option (1): ",
        1,
    )
    if option not in range(1, 4):
        print("Wrong inference model option! Falling back to defaults")
        return

    settings.use_lcm_lora = False
    settings.use_openvino = False
    if option == 1:
        lcm_model_id = input(f"Enter LCM model ID ({settings.lcm_model_id}): ")
        if lcm_model_id != "":
            settings.lcm_model_id = lcm_model_id
    elif option == 2:
        settings.use_lcm_lora = True
        lcm_lora_id = input(
            f"Enter LCM-Lora model ID ({settings.lcm_lora.lcm_lora_id}): "
        )
        if lcm_lora_id != "":
            settings.lcm_lora.lcm_lora_id = lcm_lora_id
        base_model_id = input(
            f"Enter Base model ID ({settings.lcm_lora.base_model_id}): "
        )
        if base_model_id != "":
            settings.lcm_lora.base_model_id = base_model_id
    elif option == 3:
        settings.use_openvino = True
        openvino_lcm_model_id = input(
            f"Enter OpenVINO model ID ({settings.openvino_lcm_model_id}): "
        )
        if openvino_lcm_model_id != "":
            settings.openvino_lcm_model_id = openvino_lcm_model_id

    settings.use_offline_model = True
    settings.use_tiny_auto_encoder = True
    option = input("Work offline? (Y/n): ")
    if option.upper() == "N":
        settings.use_offline_model = False
    option = input("Use Tiny Auto Encoder? (Y/n): ")
    if option.upper() == "N":
        settings.use_tiny_auto_encoder = False

    settings.image_width = user_value(
        int,
        f"Image width ({settings.image_width}): ",
        settings.image_width,
    )
    settings.image_height = user_value(
        int,
        f"Image height ({settings.image_height}): ",
        settings.image_height,
    )
    settings.inference_steps = user_value(
        int,
        f"Inference steps ({settings.inference_steps}): ",
        settings.inference_steps,
    )
    settings.guidance_scale = user_value(
        float,
        f"Guidance scale ({settings.guidance_scale}): ",
        settings.guidance_scale,
    )
    settings.number_of_images = user_value(
        int,
        f"Number of images per batch ({settings.number_of_images}): ",
        settings.number_of_images,
    )
    _batch_count = user_value(
        int,
        f"Batch count ({_batch_count}): ",
        _batch_count,
    )
    # output_format = user_value(int, f"Output format (PNG)", 1)
    print(config.lcm_diffusion_setting)


def interactive_txt2img(
    config,
    context,
):
    global _batch_count
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
    user_input = input("Write a prompt (write 'exit' to quit): ")
    while True:
        if user_input == "exit":
            return
        elif user_input == "":
            user_input = config.lcm_diffusion_setting.prompt
        config.lcm_diffusion_setting.prompt = user_input
        for i in range(0, _batch_count):
            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        user_input = input("Write a prompt: ")


def interactive_img2img(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    steps = settings.inference_steps
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in img2img mode")
        return
    settings.strength = user_value(
        float,
        f"img2img strength ({settings.strength}): ",
        settings.strength,
    )
    settings.inference_steps = int(steps / settings.strength + 1)
    user_input = input("Write a prompt (write 'exit' to quit): ")
    while True:
        if user_input == "exit":
            settings.inference_steps = steps
            return
        settings.init_image = Image.open(source_path)
        settings.prompt = user_input
        for i in range(0, _batch_count):
            context.generate_text_to_image(
                settings=config,
                device=DEVICE,
            )
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path
        settings.strength = user_value(
            float,
            f"img2img strength ({settings.strength}): ",
            settings.strength,
        )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        settings.inference_steps = int(steps / settings.strength + 1)
        user_input = input("Write a prompt: ")


def interactive_variations(
    config,
    context,
):
    global _batch_count
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    steps = settings.inference_steps
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in Image variations mode")
        return
    settings.strength = user_value(
        float,
        f"Image variations strength ({settings.strength}): ",
        settings.strength,
    )
    settings.inference_steps = int(steps / settings.strength + 1)
    while True:
        settings.init_image = Image.open(source_path)
        settings.prompt = ""
        for i in range(0, _batch_count):
            generate_image_variations(
                settings.init_image,
                settings.strength,
            )
        if _edit_lora_settings:
            interactive_lora(
                config,
                context,
            )
        user_input = input("Continue in Image variations mode? (Y/n): ")
        if user_input.upper() == "N":
            settings.inference_steps = steps
            return
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path
        settings.strength = user_value(
            float,
            f"Image variations strength ({settings.strength}): ",
            settings.strength,
        )
        settings.inference_steps = int(steps / settings.strength + 1)


def interactive_edsr(
    config,
    context,
):
    source_path = input("Image path: ")
    if source_path == "":
        print("Error : You need to provide a file in EDSR mode")
        return
    while True:
        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            source_path,
            2,
            config.generated_images.format,
        )
        result = upscale_image(
            context,
            source_path,
            output_path,
            2,
        )
        user_input = input("Continue in EDSR upscale mode? (Y/n): ")
        if user_input.upper() == "N":
            return
        new_path = input(f"Image path ({source_path}): ")
        if new_path != "":
            source_path = new_path


def interactive_sdupscale_settings(config):
    steps = config.lcm_diffusion_setting.inference_steps
    custom_settings = {}
    print("> 1. Upscale whole image")
    print("> 2. Define custom tiles (advanced)")
    option = user_value(
        int,
        "Select an SD Upscale option (1): ",
        1,
    )
    if option not in range(1, 3):
        print("Wrong SD Upscale option!")
        return

    # custom_settings["source_file"] = args.file
    custom_settings["source_file"] = ""
    new_path = input(f"Input image path ({custom_settings['source_file']}): ")
    if new_path != "":
        custom_settings["source_file"] = new_path
    if custom_settings["source_file"] == "":
        print("Error : You need to provide a file in SD Upscale mode")
        return
    custom_settings["target_file"] = None
    if option == 2:
        custom_settings["target_file"] = input("Image to patch: ")
        if custom_settings["target_file"] == "":
            print("No target file provided, upscaling whole input image instead!")
            custom_settings["target_file"] = None
            option = 1
    custom_settings["output_format"] = config.generated_images.format
    custom_settings["strength"] = user_value(
        float,
        f"SD Upscale strength ({config.lcm_diffusion_setting.strength}): ",
        config.lcm_diffusion_setting.strength,
    )
    config.lcm_diffusion_setting.inference_steps = int(
        steps / custom_settings["strength"] + 1
    )
    if option == 1:
        custom_settings["scale_factor"] = user_value(
            float,
            "Scale factor (2.0): ",
            2.0,
        )
        custom_settings["tile_size"] = user_value(
            int,
            "Split input image into tiles of the following size, in pixels (256): ",
            256,
        )
        custom_settings["tile_overlap"] = user_value(
            int,
            "Tile overlap, in pixels (16): ",
            16,
        )
    elif option == 2:
        custom_settings["scale_factor"] = user_value(
            float,
            "Input image to Image-to-patch scale_factor (2.0): ",
            2.0,
        )
        custom_settings["tile_size"] = 256
        custom_settings["tile_overlap"] = 16
    custom_settings["prompt"] = input(
        "Write a prompt describing the input image (optional): "
    )
    custom_settings["tiles"] = []
    if option == 2:
        add_tile = True
        while add_tile:
            print("=== Define custom SD Upscale tile ===")
            tile_x = user_value(
                int,
                "Enter tile's X position: ",
                0,
            )
            tile_y = user_value(
                int,
                "Enter tile's Y position: ",
                0,
            )
            tile_w = user_value(
                int,
                "Enter tile's width (256): ",
                256,
            )
            tile_h = user_value(
                int,
                "Enter tile's height (256): ",
                256,
            )
            tile_scale = user_value(
                float,
                "Enter tile's scale factor (2.0): ",
                2.0,
            )
            tile_prompt = input("Enter tile's prompt (optional): ")
            custom_settings["tiles"].append(
                {
                    "x": tile_x,
                    "y": tile_y,
                    "w": tile_w,
                    "h": tile_h,
                    "mask_box": None,
                    "prompt": tile_prompt,
                    "scale_factor": tile_scale,
                }
            )
            tile_option = input("Do you want to define another tile? (y/N): ")
            if tile_option == "" or tile_option.upper() == "N":
                add_tile = False

    return custom_settings


def interactive_sdupscale(
    config,
    context,
):
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    settings.init_image = ""
    source_path = ""
    steps = settings.inference_steps

    while True:
        custom_upscale_settings = None
        option = input("Edit custom SD Upscale settings? (y/N): ")
        if option.upper() == "Y":
            config.lcm_diffusion_setting.inference_steps = steps
            custom_upscale_settings = interactive_sdupscale_settings(config)
            if not custom_upscale_settings:
                return
            source_path = custom_upscale_settings["source_file"]
        else:
            new_path = input(f"Image path ({source_path}): ")
            if new_path != "":
                source_path = new_path
            if source_path == "":
                print("Error : You need to provide a file in SD Upscale mode")
                return
            settings.strength = user_value(
                float,
                f"SD Upscale strength ({settings.strength}): ",
                settings.strength,
            )
            settings.inference_steps = int(steps / settings.strength + 1)

        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            source_path,
            2,
            config.generated_images.format,
        )
        generate_upscaled_image(
            config,
            source_path,
            settings.strength,
            upscale_settings=custom_upscale_settings,
            context=context,
            tile_overlap=32 if settings.use_openvino else 16,
            output_path=output_path,
            image_format=config.generated_images.format,
        )
        user_input = input("Continue in SD Upscale mode? (Y/n): ")
        if user_input.upper() == "N":
            settings.inference_steps = steps
            return
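A note on the recurring int(steps / strength + 1) expression in this file: img2img pipelines only execute roughly steps * strength denoising steps, so the CLI inflates the requested count to keep the effective count near the user's setting. For example:

# With the user's 4 steps and strength 0.3:
steps, strength = 4, 0.3
requested = int(steps / strength + 1)    # 14 steps are requested
effective = round(requested * strength)  # ~4 steps actually run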
frontend/gui/app_window.py
ADDED
@@ -0,0 +1,612 @@
from PyQt5.QtWidgets import (
    QWidget,
    QPushButton,
    QHBoxLayout,
    QVBoxLayout,
    QLabel,
    QLineEdit,
    QMainWindow,
    QSlider,
    QTabWidget,
    QSpacerItem,
    QSizePolicy,
    QComboBox,
    QCheckBox,
    QTextEdit,
    QToolButton,
    QFileDialog,
)
from PyQt5 import QtWidgets, QtCore
from PyQt5.QtGui import QPixmap, QDesktopServices
from PyQt5.QtCore import QSize, QThreadPool, Qt, QUrl

from PIL.ImageQt import ImageQt
from constants import (
    LCM_DEFAULT_MODEL,
    LCM_DEFAULT_MODEL_OPENVINO,
    APP_NAME,
    APP_VERSION,
)
from frontend.gui.image_generator_worker import ImageGeneratorWorker
from app_settings import AppSettings
from paths import FastStableDiffusionPaths
from frontend.utils import is_reshape_required
from context import Context
from models.interface_types import InterfaceType
from constants import DEVICE
from frontend.utils import enable_openvino_controls, get_valid_model_id
from backend.models.lcmdiffusion_setting import DiffusionTask

# DPI scale fix
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)


class MainWindow(QMainWindow):
    def __init__(self, config: AppSettings):
        super().__init__()
        self.config = config
        # Prevent saved LoRA and ControlNet settings from being used by
        # default; in GUI mode, the user must explicitly enable those
        if self.config.settings.lcm_diffusion_setting.lora:
            self.config.settings.lcm_diffusion_setting.lora.enabled = False
        if self.config.settings.lcm_diffusion_setting.controlnet:
            self.config.settings.lcm_diffusion_setting.controlnet.enabled = False
        self.setWindowTitle(APP_NAME)
        self.setFixedSize(QSize(600, 670))
        self.init_ui()
        self.pipeline = None
        self.threadpool = QThreadPool()
        self.device = "cpu"
        self.previous_width = 0
        self.previous_height = 0
        self.previous_model = ""
        self.previous_num_of_images = 0
        self.context = Context(InterfaceType.GUI)
        self.init_ui_values()
        self.gen_images = []
        self.image_index = 0
        print(f"Output path : {self.config.settings.generated_images.path}")

    def init_ui_values(self):
        self.lcm_model.setEnabled(
            not self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.guidance.setValue(
            int(self.config.settings.lcm_diffusion_setting.guidance_scale * 10)
        )
        self.seed_value.setEnabled(self.config.settings.lcm_diffusion_setting.use_seed)
        self.safety_checker.setChecked(
            self.config.settings.lcm_diffusion_setting.use_safety_checker
        )
        self.use_openvino_check.setChecked(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.width.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_width)
        )
        self.height.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_height)
        )
        self.inference_steps.setValue(
            int(self.config.settings.lcm_diffusion_setting.inference_steps)
        )
        self.seed_check.setChecked(self.config.settings.lcm_diffusion_setting.use_seed)
        self.seed_value.setText(str(self.config.settings.lcm_diffusion_setting.seed))
        self.use_local_model_folder.setChecked(
            self.config.settings.lcm_diffusion_setting.use_offline_model
        )
        self.results_path.setText(self.config.settings.generated_images.path)
        self.num_images.setValue(
            self.config.settings.lcm_diffusion_setting.number_of_images
        )
        self.use_tae_sd.setChecked(
            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder
        )
        self.use_lcm_lora.setChecked(
            self.config.settings.lcm_diffusion_setting.use_lcm_lora
        )
        self.lcm_model.setCurrentText(
            get_valid_model_id(
                self.config.lcm_models,
                self.config.settings.lcm_diffusion_setting.lcm_model_id,
                LCM_DEFAULT_MODEL,
            )
        )
        self.base_model_id.setCurrentText(
            get_valid_model_id(
                self.config.stable_diffsuion_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
            )
        )
        self.lcm_lora_id.setCurrentText(
            get_valid_model_id(
                self.config.lcm_lora_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
            )
        )
        self.openvino_lcm_model_id.setCurrentText(
            get_valid_model_id(
                self.config.openvino_lcm_models,
                self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id,
                LCM_DEFAULT_MODEL_OPENVINO,
            )
        )
        self.neg_prompt.setEnabled(
            self.config.settings.lcm_diffusion_setting.use_lcm_lora
            or self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.openvino_lcm_model_id.setEnabled(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )

    def init_ui(self):
        self.create_main_tab()
        self.create_settings_tab()
        self.create_about_tab()
        self.show()

    def create_main_tab(self):
        self.img = QLabel("<<Image>>")
        self.img.setAlignment(Qt.AlignCenter)
        self.img.setFixedSize(QSize(512, 512))
        self.vspacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)

        self.prompt = QTextEdit()
        self.prompt.setPlaceholderText("A fantasy landscape")
        self.prompt.setAcceptRichText(False)
        self.neg_prompt = QTextEdit()
        self.neg_prompt.setPlaceholderText("")
        self.neg_prompt.setAcceptRichText(False)
        self.neg_prompt_label = QLabel("Negative prompt (Set guidance scale > 1.0):")
        self.generate = QPushButton("Generate")
        self.generate.clicked.connect(self.text_to_image)
        self.prompt.setFixedHeight(40)
        self.neg_prompt.setFixedHeight(35)
        self.browse_results = QPushButton("...")
        self.browse_results.setFixedWidth(30)
        self.browse_results.clicked.connect(self.on_open_results_folder)
        self.browse_results.setToolTip("Open output folder")

        hlayout = QHBoxLayout()
        hlayout.addWidget(self.neg_prompt)
        hlayout.addWidget(self.generate)
        hlayout.addWidget(self.browse_results)

        self.previous_img_btn = QToolButton()
        self.previous_img_btn.setText("<")
        self.previous_img_btn.clicked.connect(self.on_show_previous_image)
        self.next_img_btn = QToolButton()
        self.next_img_btn.setText(">")
        self.next_img_btn.clicked.connect(self.on_show_next_image)
        hlayout_nav = QHBoxLayout()
        hlayout_nav.addWidget(self.previous_img_btn)
        hlayout_nav.addWidget(self.img)
        hlayout_nav.addWidget(self.next_img_btn)

        vlayout = QVBoxLayout()
        vlayout.addLayout(hlayout_nav)
        vlayout.addItem(self.vspacer)
        vlayout.addWidget(self.prompt)
        vlayout.addWidget(self.neg_prompt_label)
        vlayout.addLayout(hlayout)

        self.tab_widget = QTabWidget(self)
        self.tab_main = QWidget()
        self.tab_settings = QWidget()
        self.tab_about = QWidget()
        self.tab_main.setLayout(vlayout)

        self.tab_widget.addTab(self.tab_main, "Text to Image")
        self.tab_widget.addTab(self.tab_settings, "Settings")
        self.tab_widget.addTab(self.tab_about, "About")

        self.setCentralWidget(self.tab_widget)
        self.use_seed = False

    def create_settings_tab(self):
        self.lcm_model_label = QLabel("Latent Consistency Model:")
        # self.lcm_model = QLineEdit(LCM_DEFAULT_MODEL)
        self.lcm_model = QComboBox(self)
        self.lcm_model.addItems(self.config.lcm_models)
        self.lcm_model.currentIndexChanged.connect(self.on_lcm_model_changed)

        self.use_lcm_lora = QCheckBox("Use LCM LoRA")
        self.use_lcm_lora.setChecked(False)
        self.use_lcm_lora.stateChanged.connect(self.use_lcm_lora_changed)

        self.lora_base_model_id_label = QLabel("Lora base model ID :")
        self.base_model_id = QComboBox(self)
        self.base_model_id.addItems(self.config.stable_diffsuion_models)
        self.base_model_id.currentIndexChanged.connect(self.on_base_model_id_changed)

        self.lcm_lora_model_id_label = QLabel("LCM LoRA model ID :")
        self.lcm_lora_id = QComboBox(self)
        self.lcm_lora_id.addItems(self.config.lcm_lora_models)
        self.lcm_lora_id.currentIndexChanged.connect(self.on_lcm_lora_id_changed)

        self.inference_steps_value = QLabel("Number of inference steps: 4")
        self.inference_steps = QSlider(orientation=Qt.Orientation.Horizontal)
        self.inference_steps.setMaximum(25)
        self.inference_steps.setMinimum(1)
        self.inference_steps.setValue(4)
        self.inference_steps.valueChanged.connect(self.update_steps_label)

        self.num_images_value = QLabel("Number of images: 1")
        self.num_images = QSlider(orientation=Qt.Orientation.Horizontal)
        self.num_images.setMaximum(100)
        self.num_images.setMinimum(1)
        self.num_images.setValue(1)
        self.num_images.valueChanged.connect(self.update_num_images_label)

        self.guidance_value = QLabel("Guidance scale: 1")
        self.guidance = QSlider(orientation=Qt.Orientation.Horizontal)
        self.guidance.setMaximum(20)
        self.guidance.setMinimum(10)
        self.guidance.setValue(10)
        self.guidance.valueChanged.connect(self.update_guidance_label)

        self.width_value = QLabel("Width :")
        self.width = QComboBox(self)
        self.width.addItem("256")
        self.width.addItem("512")
        self.width.addItem("768")
        self.width.addItem("1024")
        self.width.setCurrentText("512")
        self.width.currentIndexChanged.connect(self.on_width_changed)

        self.height_value = QLabel("Height :")
        self.height = QComboBox(self)
        self.height.addItem("256")
        self.height.addItem("512")
        self.height.addItem("768")
        self.height.addItem("1024")
        self.height.setCurrentText("512")
        self.height.currentIndexChanged.connect(self.on_height_changed)

        self.seed_check = QCheckBox("Use seed")
        self.seed_value = QLineEdit()
        self.seed_value.setInputMask("9999999999")
        self.seed_value.setText("123123")
        self.seed_check.stateChanged.connect(self.seed_changed)

        self.safety_checker = QCheckBox("Use safety checker")
        self.safety_checker.setChecked(True)
        self.safety_checker.stateChanged.connect(self.use_safety_checker_changed)

        self.use_openvino_check = QCheckBox("Use OpenVINO")
        self.use_openvino_check.setChecked(False)
        self.openvino_model_label = QLabel("OpenVINO LCM model:")
        self.use_local_model_folder = QCheckBox(
            "Use locally cached model or downloaded model folder(offline)"
        )
        self.openvino_lcm_model_id = QComboBox(self)
        self.openvino_lcm_model_id.addItems(self.config.openvino_lcm_models)
        self.openvino_lcm_model_id.currentIndexChanged.connect(
            self.on_openvino_lcm_model_id_changed
        )

        self.use_openvino_check.setEnabled(enable_openvino_controls())
        self.use_local_model_folder.setChecked(False)
        self.use_local_model_folder.stateChanged.connect(self.use_offline_model_changed)
        self.use_openvino_check.stateChanged.connect(self.use_openvino_changed)

        self.use_tae_sd = QCheckBox(
            "Use Tiny Auto Encoder - TAESD (Fast, moderate quality)"
        )
        self.use_tae_sd.setChecked(False)
        self.use_tae_sd.stateChanged.connect(self.use_tae_sd_changed)

        hlayout = QHBoxLayout()
        hlayout.addWidget(self.seed_check)
        hlayout.addWidget(self.seed_value)
        hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
        slider_hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)

        self.results_path_label = QLabel("Output path:")
        self.results_path = QLineEdit()
        self.results_path.textChanged.connect(self.on_path_changed)
        self.browse_folder_btn = QToolButton()
        self.browse_folder_btn.setText("...")
        self.browse_folder_btn.clicked.connect(self.on_browse_folder)

        self.reset = QPushButton("Reset All")
        self.reset.clicked.connect(self.reset_all_settings)

        vlayout = QVBoxLayout()
        vspacer = QSpacerItem(20, 20, QSizePolicy.Minimum, QSizePolicy.Expanding)
        vlayout.addItem(hspacer)
        vlayout.setSpacing(3)
        vlayout.addWidget(self.lcm_model_label)
        vlayout.addWidget(self.lcm_model)
        vlayout.addWidget(self.use_local_model_folder)
        vlayout.addWidget(self.use_lcm_lora)
        vlayout.addWidget(self.lora_base_model_id_label)
        vlayout.addWidget(self.base_model_id)
        vlayout.addWidget(self.lcm_lora_model_id_label)
        vlayout.addWidget(self.lcm_lora_id)
        vlayout.addWidget(self.use_openvino_check)
        vlayout.addWidget(self.openvino_model_label)
        vlayout.addWidget(self.openvino_lcm_model_id)
        vlayout.addWidget(self.use_tae_sd)
        vlayout.addItem(slider_hspacer)
        vlayout.addWidget(self.inference_steps_value)
        vlayout.addWidget(self.inference_steps)
        vlayout.addWidget(self.num_images_value)
        vlayout.addWidget(self.num_images)
        vlayout.addWidget(self.width_value)
        vlayout.addWidget(self.width)
        vlayout.addWidget(self.height_value)
        vlayout.addWidget(self.height)
        vlayout.addWidget(self.guidance_value)
        vlayout.addWidget(self.guidance)
        vlayout.addLayout(hlayout)
        vlayout.addWidget(self.safety_checker)

        vlayout.addWidget(self.results_path_label)
        hlayout_path = QHBoxLayout()
        hlayout_path.addWidget(self.results_path)
        hlayout_path.addWidget(self.browse_folder_btn)
        vlayout.addLayout(hlayout_path)
        self.tab_settings.setLayout(vlayout)
        hlayout_reset = QHBoxLayout()
        hspacer = QSpacerItem(20, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
        hlayout_reset.addItem(hspacer)
        hlayout_reset.addWidget(self.reset)
        vlayout.addLayout(hlayout_reset)
        vlayout.addItem(vspacer)

    def create_about_tab(self):
        self.label = QLabel()
        self.label.setAlignment(Qt.AlignCenter)
        self.label.setText(
363 |
+
f"""<h1>FastSD CPU {APP_VERSION}</h1>
|
364 |
+
<h3>(c)2023 - 2024 Rupesh Sreeraman</h3>
|
365 |
+
<h3>Faster stable diffusion on CPU</h3>
|
366 |
+
<h3>Based on Latent Consistency Models</h3>
|
367 |
+
<h3>GitHub : https://github.com/rupeshs/fastsdcpu/</h3>"""
|
368 |
+
)
|
369 |
+
|
370 |
+
vlayout = QVBoxLayout()
|
371 |
+
vlayout.addWidget(self.label)
|
372 |
+
self.tab_about.setLayout(vlayout)
|
373 |
+
|
374 |
+
def show_image(self, pixmap):
|
375 |
+
image_width = self.config.settings.lcm_diffusion_setting.image_width
|
376 |
+
image_height = self.config.settings.lcm_diffusion_setting.image_height
|
377 |
+
if image_width > 512 or image_height > 512:
|
378 |
+
new_width = 512 if image_width > 512 else image_width
|
379 |
+
new_height = 512 if image_height > 512 else image_height
|
380 |
+
self.img.setPixmap(
|
381 |
+
pixmap.scaled(
|
382 |
+
new_width,
|
383 |
+
new_height,
|
384 |
+
Qt.KeepAspectRatio,
|
385 |
+
)
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
self.img.setPixmap(pixmap)
|
389 |
+
|
390 |
+
def on_show_next_image(self):
|
391 |
+
if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
|
392 |
+
self.previous_img_btn.setEnabled(True)
|
393 |
+
self.image_index += 1
|
394 |
+
self.show_image(self.gen_images[self.image_index])
|
395 |
+
if self.image_index == len(self.gen_images) - 1:
|
396 |
+
self.next_img_btn.setEnabled(False)
|
397 |
+
|
398 |
+
def on_open_results_folder(self):
|
399 |
+
QDesktopServices.openUrl(
|
400 |
+
QUrl.fromLocalFile(self.config.settings.generated_images.path)
|
401 |
+
)
|
402 |
+
|
403 |
+
def on_show_previous_image(self):
|
404 |
+
if self.image_index != 0:
|
405 |
+
self.next_img_btn.setEnabled(True)
|
406 |
+
self.image_index -= 1
|
407 |
+
self.show_image(self.gen_images[self.image_index])
|
408 |
+
if self.image_index == 0:
|
409 |
+
self.previous_img_btn.setEnabled(False)
|
410 |
+
|
411 |
+
def on_path_changed(self, text):
|
412 |
+
self.config.settings.generated_images.path = text
|
413 |
+
|
414 |
+
def on_browse_folder(self):
|
415 |
+
options = QFileDialog.Options()
|
416 |
+
options |= QFileDialog.ShowDirsOnly
|
417 |
+
|
418 |
+
folder_path = QFileDialog.getExistingDirectory(
|
419 |
+
self, "Select a Folder", "", options=options
|
420 |
+
)
|
421 |
+
|
422 |
+
if folder_path:
|
423 |
+
self.config.settings.generated_images.path = folder_path
|
424 |
+
self.results_path.setText(folder_path)
|
425 |
+
|
426 |
+
def on_width_changed(self, index):
|
427 |
+
width_txt = self.width.itemText(index)
|
428 |
+
self.config.settings.lcm_diffusion_setting.image_width = int(width_txt)
|
429 |
+
|
430 |
+
def on_height_changed(self, index):
|
431 |
+
height_txt = self.height.itemText(index)
|
432 |
+
self.config.settings.lcm_diffusion_setting.image_height = int(height_txt)
|
433 |
+
|
434 |
+
def on_lcm_model_changed(self, index):
|
435 |
+
model_id = self.lcm_model.itemText(index)
|
436 |
+
self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id
|
437 |
+
|
438 |
+
def on_base_model_id_changed(self, index):
|
439 |
+
model_id = self.base_model_id.itemText(index)
|
440 |
+
self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id
|
441 |
+
|
442 |
+
def on_lcm_lora_id_changed(self, index):
|
443 |
+
model_id = self.lcm_lora_id.itemText(index)
|
444 |
+
self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id
|
445 |
+
|
446 |
+
def on_openvino_lcm_model_id_changed(self, index):
|
447 |
+
model_id = self.openvino_lcm_model_id.itemText(index)
|
448 |
+
self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
|
449 |
+
|
450 |
+
def use_openvino_changed(self, state):
|
451 |
+
if state == 2:
|
452 |
+
self.lcm_model.setEnabled(False)
|
453 |
+
self.use_lcm_lora.setEnabled(False)
|
454 |
+
self.lcm_lora_id.setEnabled(False)
|
455 |
+
self.base_model_id.setEnabled(False)
|
456 |
+
self.neg_prompt.setEnabled(True)
|
457 |
+
self.openvino_lcm_model_id.setEnabled(True)
|
458 |
+
self.config.settings.lcm_diffusion_setting.use_openvino = True
|
459 |
+
else:
|
460 |
+
self.lcm_model.setEnabled(True)
|
461 |
+
self.use_lcm_lora.setEnabled(True)
|
462 |
+
self.lcm_lora_id.setEnabled(True)
|
463 |
+
self.base_model_id.setEnabled(True)
|
464 |
+
self.neg_prompt.setEnabled(False)
|
465 |
+
self.openvino_lcm_model_id.setEnabled(False)
|
466 |
+
self.config.settings.lcm_diffusion_setting.use_openvino = False
|
467 |
+
|
468 |
+
def use_tae_sd_changed(self, state):
|
469 |
+
if state == 2:
|
470 |
+
self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = True
|
471 |
+
else:
|
472 |
+
self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = False
|
473 |
+
|
474 |
+
def use_offline_model_changed(self, state):
|
475 |
+
if state == 2:
|
476 |
+
self.config.settings.lcm_diffusion_setting.use_offline_model = True
|
477 |
+
else:
|
478 |
+
self.config.settings.lcm_diffusion_setting.use_offline_model = False
|
479 |
+
|
480 |
+
def use_lcm_lora_changed(self, state):
|
481 |
+
if state == 2:
|
482 |
+
self.lcm_model.setEnabled(False)
|
483 |
+
self.lcm_lora_id.setEnabled(True)
|
484 |
+
self.base_model_id.setEnabled(True)
|
485 |
+
self.neg_prompt.setEnabled(True)
|
486 |
+
self.config.settings.lcm_diffusion_setting.use_lcm_lora = True
|
487 |
+
else:
|
488 |
+
self.lcm_model.setEnabled(True)
|
489 |
+
self.lcm_lora_id.setEnabled(False)
|
490 |
+
self.base_model_id.setEnabled(False)
|
491 |
+
self.neg_prompt.setEnabled(False)
|
492 |
+
self.config.settings.lcm_diffusion_setting.use_lcm_lora = False
|
493 |
+
|
494 |
+
def use_safety_checker_changed(self, state):
|
495 |
+
if state == 2:
|
496 |
+
self.config.settings.lcm_diffusion_setting.use_safety_checker = True
|
497 |
+
else:
|
498 |
+
self.config.settings.lcm_diffusion_setting.use_safety_checker = False
|
499 |
+
|
500 |
+
def update_steps_label(self, value):
|
501 |
+
self.inference_steps_value.setText(f"Number of inference steps: {value}")
|
502 |
+
self.config.settings.lcm_diffusion_setting.inference_steps = value
|
503 |
+
|
504 |
+
def update_num_images_label(self, value):
|
505 |
+
self.num_images_value.setText(f"Number of images: {value}")
|
506 |
+
self.config.settings.lcm_diffusion_setting.number_of_images = value
|
507 |
+
|
508 |
+
def update_guidance_label(self, value):
|
509 |
+
val = round(int(value) / 10, 1)
|
510 |
+
self.guidance_value.setText(f"Guidance scale: {val}")
|
511 |
+
self.config.settings.lcm_diffusion_setting.guidance_scale = val
|
512 |
+
|
513 |
+
def seed_changed(self, state):
|
514 |
+
if state == 2:
|
515 |
+
self.seed_value.setEnabled(True)
|
516 |
+
self.config.settings.lcm_diffusion_setting.use_seed = True
|
517 |
+
else:
|
518 |
+
self.seed_value.setEnabled(False)
|
519 |
+
self.config.settings.lcm_diffusion_setting.use_seed = False
|
520 |
+
|
521 |
+
def get_seed_value(self) -> int:
|
522 |
+
use_seed = self.config.settings.lcm_diffusion_setting.use_seed
|
523 |
+
seed_value = int(self.seed_value.text()) if use_seed else -1
|
524 |
+
return seed_value
|
525 |
+
|
526 |
+
def generate_image(self):
|
527 |
+
self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
|
528 |
+
self.config.settings.lcm_diffusion_setting.prompt = self.prompt.toPlainText()
|
529 |
+
self.config.settings.lcm_diffusion_setting.negative_prompt = (
|
530 |
+
self.neg_prompt.toPlainText()
|
531 |
+
)
|
532 |
+
self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = (
|
533 |
+
self.lcm_lora_id.currentText()
|
534 |
+
)
|
535 |
+
self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = (
|
536 |
+
self.base_model_id.currentText()
|
537 |
+
)
|
538 |
+
|
539 |
+
if self.config.settings.lcm_diffusion_setting.use_openvino:
|
540 |
+
model_id = self.openvino_lcm_model_id.currentText()
|
541 |
+
self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
|
542 |
+
else:
|
543 |
+
model_id = self.lcm_model.currentText()
|
544 |
+
self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id
|
545 |
+
|
546 |
+
reshape_required = False
|
547 |
+
if self.config.settings.lcm_diffusion_setting.use_openvino:
|
548 |
+
# Detect dimension change
|
549 |
+
reshape_required = is_reshape_required(
|
550 |
+
self.previous_width,
|
551 |
+
self.config.settings.lcm_diffusion_setting.image_width,
|
552 |
+
self.previous_height,
|
553 |
+
self.config.settings.lcm_diffusion_setting.image_height,
|
554 |
+
self.previous_model,
|
555 |
+
model_id,
|
556 |
+
self.previous_num_of_images,
|
557 |
+
self.config.settings.lcm_diffusion_setting.number_of_images,
|
558 |
+
)
|
559 |
+
self.config.settings.lcm_diffusion_setting.diffusion_task = (
|
560 |
+
DiffusionTask.text_to_image.value
|
561 |
+
)
|
562 |
+
images = self.context.generate_text_to_image(
|
563 |
+
self.config.settings,
|
564 |
+
reshape_required,
|
565 |
+
DEVICE,
|
566 |
+
)
|
567 |
+
self.image_index = 0
|
568 |
+
self.gen_images = []
|
569 |
+
for img in images:
|
570 |
+
im = ImageQt(img).copy()
|
571 |
+
pixmap = QPixmap.fromImage(im)
|
572 |
+
self.gen_images.append(pixmap)
|
573 |
+
|
574 |
+
if len(self.gen_images) > 1:
|
575 |
+
self.next_img_btn.setEnabled(True)
|
576 |
+
self.previous_img_btn.setEnabled(False)
|
577 |
+
else:
|
578 |
+
self.next_img_btn.setEnabled(False)
|
579 |
+
self.previous_img_btn.setEnabled(False)
|
580 |
+
|
581 |
+
self.show_image(self.gen_images[0])
|
582 |
+
|
583 |
+
self.previous_width = self.config.settings.lcm_diffusion_setting.image_width
|
584 |
+
self.previous_height = self.config.settings.lcm_diffusion_setting.image_height
|
585 |
+
self.previous_model = model_id
|
586 |
+
self.previous_num_of_images = (
|
587 |
+
self.config.settings.lcm_diffusion_setting.number_of_images
|
588 |
+
)
|
589 |
+
|
590 |
+
def text_to_image(self):
|
591 |
+
self.img.setText("Please wait...")
|
592 |
+
worker = ImageGeneratorWorker(self.generate_image)
|
593 |
+
self.threadpool.start(worker)
|
594 |
+
|
595 |
+
def closeEvent(self, event):
|
596 |
+
self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
|
597 |
+
print(self.config.settings.lcm_diffusion_setting)
|
598 |
+
print("Saving settings")
|
599 |
+
self.config.save()
|
600 |
+
|
601 |
+
def reset_all_settings(self):
|
602 |
+
self.use_local_model_folder.setChecked(False)
|
603 |
+
self.width.setCurrentText("512")
|
604 |
+
self.height.setCurrentText("512")
|
605 |
+
self.inference_steps.setValue(4)
|
606 |
+
self.guidance.setValue(10)
|
607 |
+
self.use_openvino_check.setChecked(False)
|
608 |
+
self.seed_check.setChecked(False)
|
609 |
+
self.safety_checker.setChecked(False)
|
610 |
+
self.results_path.setText(FastStableDiffusionPaths().get_results_path())
|
611 |
+
self.use_tae_sd.setChecked(False)
|
612 |
+
self.use_lcm_lora.setChecked(False)
|
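A note on the `stateChanged` handlers in app_window.py above: PyQt5 passes the raw check-state integer to the slot, so the `state == 2` comparisons test for `Qt.Checked` (0 is `Qt.Unchecked`, 1 is `Qt.PartiallyChecked`). A minimal standalone sketch of the same test written against the enum:

from PyQt5.QtCore import Qt

def on_state_changed(state: int):
    # Qt.Checked has the integer value 2, which is what the handlers compare against
    is_checked = state == Qt.Checked
    print("checked:", is_checked)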
frontend/gui/image_generator_worker.py
ADDED
@@ -0,0 +1,37 @@
+from PyQt5.QtCore import (
+    pyqtSlot,
+    QRunnable,
+    pyqtSignal,
+    pyqtSlot,
+)
+from PyQt5.QtCore import QObject
+import traceback
+import sys
+
+
+class WorkerSignals(QObject):
+    finished = pyqtSignal()
+    error = pyqtSignal(tuple)
+    result = pyqtSignal(object)
+
+
+class ImageGeneratorWorker(QRunnable):
+    def __init__(self, fn, *args, **kwargs):
+        super(ImageGeneratorWorker, self).__init__()
+        self.fn = fn
+        self.args = args
+        self.kwargs = kwargs
+        self.signals = WorkerSignals()
+
+    @pyqtSlot()
+    def run(self):
+        try:
+            result = self.fn(*self.args, **self.kwargs)
+        except:
+            traceback.print_exc()
+            exctype, value = sys.exc_info()[:2]
+            self.signals.error.emit((exctype, value, traceback.format_exc()))
+        else:
+            self.signals.result.emit(result)
+        finally:
+            self.signals.finished.emit()
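As a usage sketch of the worker above (mirroring how app_window.py queues it with `self.threadpool.start(worker)`): the runnable executes on a `QThreadPool` thread and reports back through its `WorkerSignals`; the `generate` function here is a hypothetical stand-in for `MainWindow.generate_image`:

from PyQt5.QtCore import QThreadPool

def generate():
    return "done"  # stand-in for the real generation call

worker = ImageGeneratorWorker(generate)
worker.signals.result.connect(print)                 # receives generate()'s return value
worker.signals.finished.connect(lambda: print("finished"))
QThreadPool.globalInstance().start(worker)           # runs generate() off the GUI thread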
frontend/gui/ui.py
ADDED
@@ -0,0 +1,15 @@
+from typing import List
+from frontend.gui.app_window import MainWindow
+from PyQt5.QtWidgets import QApplication
+import sys
+from app_settings import AppSettings
+
+
+def start_gui(
+    argv: List[str],
+    app_settings: AppSettings,
+):
+    app = QApplication(sys.argv)
+    window = MainWindow(app_settings)
+    window.show()
+    app.exec()
frontend/utils.py
ADDED
@@ -0,0 +1,83 @@
+import platform
+from os import path
+from typing import List
+
+from backend.device import is_openvino_device
+from constants import DEVICE
+from paths import get_file_name
+
+
+def is_reshape_required(
+    prev_width: int,
+    cur_width: int,
+    prev_height: int,
+    cur_height: int,
+    prev_model: int,
+    cur_model: int,
+    prev_num_of_images: int,
+    cur_num_of_images: int,
+) -> bool:
+    reshape_required = False
+    if (
+        prev_width != cur_width
+        or prev_height != cur_height
+        or prev_model != cur_model
+        or prev_num_of_images != cur_num_of_images
+    ):
+        print("Reshape and compile")
+        reshape_required = True
+
+    return reshape_required
+
+
+def enable_openvino_controls() -> bool:
+    return is_openvino_device() and platform.system().lower() != "darwin" and platform.processor().lower() != 'arm'
+
+
+
+def get_valid_model_id(
+    models: List,
+    model_id: str,
+    default_model: str = "",
+) -> str:
+    if len(models) == 0:
+        print("Error: model configuration file is empty,please add some models.")
+        return ""
+    if model_id == "":
+        if default_model:
+            return default_model
+        else:
+            return models[0]
+
+    if model_id in models:
+        return model_id
+    else:
+        print(
+            f"Error:{model_id} Model not found in configuration file,so using first model : {models[0]}"
+        )
+        return models[0]
+
+
+def get_valid_lora_model(
+    models: List,
+    cur_model: str,
+    lora_models_dir: str,
+) -> str:
+    if cur_model == "" or cur_model is None:
+        print(
+            f"No lora models found, please add lora models to {lora_models_dir} directory"
+        )
+        return ""
+    else:
+        if path.exists(cur_model):
+            return get_file_name(cur_model)
+        else:
+            print(f"Lora model {cur_model} not found")
+            if len(models) > 0:
+                print(f"Fallback model - {models[0]}")
+                return get_file_name(models[0])
+            else:
+                print(
+                    f"No lora models found, please add lora models to {lora_models_dir} directory"
+                )
+                return ""
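The fallback order in `get_valid_model_id` above is: the configured id when it appears in the list; the `default_model` argument when the id is empty; otherwise the first list entry. A short illustration with made-up model ids:

models = ["org/model-a", "org/model-b"]

print(get_valid_model_id(models, "org/model-b"))          # "org/model-b" (found in list)
print(get_valid_model_id(models, "", "org/model-a"))      # "org/model-a" (empty id, default used)
print(get_valid_model_id(models, "org/unknown"))          # "org/model-a" (not found, first entry, with a warning)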
frontend/webui/controlnet_ui.py
ADDED
@@ -0,0 +1,194 @@
+import gradio as gr
+from PIL import Image
+from backend.lora import get_lora_models
+from state import get_settings
+from backend.models.lcmdiffusion_setting import ControlNetSetting
+from backend.annotators.image_control_factory import ImageControlFactory
+
+_controlnet_models_map = None
+_controlnet_enabled = False
+_adapter_path = None
+
+app_settings = get_settings()
+
+
+def on_user_input(
+    enable: bool,
+    adapter_name: str,
+    conditioning_scale: float,
+    control_image: Image,
+    preprocessor: str,
+):
+    if not isinstance(adapter_name, str):
+        gr.Warning("Please select a valid ControlNet model")
+        return gr.Checkbox(value=False)
+
+    settings = app_settings.settings.lcm_diffusion_setting
+    if settings.controlnet is None:
+        settings.controlnet = ControlNetSetting()
+
+    if enable and (adapter_name is None or adapter_name == ""):
+        gr.Warning("Please select a valid ControlNet adapter")
+        return gr.Checkbox(value=False)
+    elif enable and not control_image:
+        gr.Warning("Please provide a ControlNet control image")
+        return gr.Checkbox(value=False)
+
+    if control_image is None:
+        return gr.Checkbox(value=enable)
+
+    if preprocessor == "None":
+        processed_control_image = control_image
+    else:
+        image_control_factory = ImageControlFactory()
+        control = image_control_factory.create_control(preprocessor)
+        processed_control_image = control.get_control_image(control_image)
+
+    if not enable:
+        settings.controlnet.enabled = False
+    else:
+        settings.controlnet.enabled = True
+        settings.controlnet.adapter_path = _controlnet_models_map[adapter_name]
+        settings.controlnet.conditioning_scale = float(conditioning_scale)
+        settings.controlnet._control_image = processed_control_image
+
+    # This code can be improved; currently, if the user clicks the
+    # "Enable ControlNet" checkbox or changes the currently selected
+    # ControlNet model, it will trigger a pipeline rebuild even if, in
+    # the end, the user leaves the same ControlNet settings
+    global _controlnet_enabled
+    global _adapter_path
+    if settings.controlnet.enabled != _controlnet_enabled or (
+        settings.controlnet.enabled
+        and settings.controlnet.adapter_path != _adapter_path
+    ):
+        settings.rebuild_pipeline = True
+        _controlnet_enabled = settings.controlnet.enabled
+        _adapter_path = settings.controlnet.adapter_path
+    return gr.Checkbox(value=enable)
+
+
+def on_change_conditioning_scale(cond_scale):
+    print(cond_scale)
+    app_settings.settings.lcm_diffusion_setting.controlnet.conditioning_scale = (
+        cond_scale
+    )
+
+
+def get_controlnet_ui() -> None:
+    with gr.Blocks() as ui:
+        gr.HTML(
+            'Download ControlNet v1.1 model from <a href="https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/tree/main">ControlNet v1.1 </a> (723 MB files) and place it in <b>controlnet_models</b> folder,restart the app'
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    global _controlnet_models_map
+                    _controlnet_models_map = get_lora_models(
+                        app_settings.settings.lcm_diffusion_setting.dirs["controlnet"]
+                    )
+                    controlnet_models = list(_controlnet_models_map.keys())
+                    default_model = (
+                        controlnet_models[0] if len(controlnet_models) else None
+                    )
+
+                    enabled_checkbox = gr.Checkbox(
+                        label="Enable ControlNet",
+                        info="Enable ControlNet",
+                        show_label=True,
+                    )
+                    model_dropdown = gr.Dropdown(
+                        _controlnet_models_map.keys(),
+                        label="ControlNet model",
+                        info="ControlNet model to load (.safetensors format)",
+                        value=default_model,
+                        interactive=True,
+                    )
+                    conditioning_scale_slider = gr.Slider(
+                        0.0,
+                        1.0,
+                        value=0.5,
+                        step=0.05,
+                        label="ControlNet conditioning scale",
+                        interactive=True,
+                    )
+                    control_image = gr.Image(
+                        label="Control image",
+                        type="pil",
+                    )
+                    preprocessor_radio = gr.Radio(
+                        [
+                            "Canny",
+                            "Depth",
+                            "LineArt",
+                            "MLSD",
+                            "NormalBAE",
+                            "Pose",
+                            "SoftEdge",
+                            "Shuffle",
+                            "None",
+                        ],
+                        label="Preprocessor",
+                        info="Select the preprocessor for the control image",
+                        value="Canny",
+                        interactive=True,
+                    )
+
+        enabled_checkbox.input(
+            fn=on_user_input,
+            inputs=[
+                enabled_checkbox,
+                model_dropdown,
+                conditioning_scale_slider,
+                control_image,
+                preprocessor_radio,
+            ],
+            outputs=[enabled_checkbox],
+        )
+        model_dropdown.input(
+            fn=on_user_input,
+            inputs=[
+                enabled_checkbox,
+                model_dropdown,
+                conditioning_scale_slider,
+                control_image,
+                preprocessor_radio,
+            ],
+            outputs=[enabled_checkbox],
+        )
+        conditioning_scale_slider.input(
+            fn=on_user_input,
+            inputs=[
+                enabled_checkbox,
+                model_dropdown,
+                conditioning_scale_slider,
+                control_image,
+                preprocessor_radio,
+            ],
+            outputs=[enabled_checkbox],
+        )
+        control_image.change(
+            fn=on_user_input,
+            inputs=[
+                enabled_checkbox,
+                model_dropdown,
+                conditioning_scale_slider,
+                control_image,
+                preprocessor_radio,
+            ],
+            outputs=[enabled_checkbox],
+        )
+        preprocessor_radio.change(
+            fn=on_user_input,
+            inputs=[
+                enabled_checkbox,
+                model_dropdown,
+                conditioning_scale_slider,
+                control_image,
+                preprocessor_radio,
+            ],
+            outputs=[enabled_checkbox],
+        )
+        conditioning_scale_slider.change(
+            on_change_conditioning_scale, conditioning_scale_slider
+        )
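The preprocessor branch in `on_user_input` above delegates to `ImageControlFactory`, which maps the selected radio label to an annotator. A minimal sketch of that dispatch in isolation, assuming `control.png` is a hypothetical input file and that the annotator returns a PIL image (as its use above suggests):

from PIL import Image
from backend.annotators.image_control_factory import ImageControlFactory

control_image = Image.open("control.png")                # hypothetical input file
control = ImageControlFactory().create_control("Canny")  # one of the radio labels above
processed = control.get_control_image(control_image)
processed.save("control_canny.png")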
frontend/webui/css/style.css
ADDED
@@ -0,0 +1,22 @@
+footer {
+    visibility: hidden
+}
+
+#generate_button {
+    color: white;
+    border-color: #007bff;
+    background: #2563eb;
+
+}
+
+#save_button {
+    color: white;
+    border-color: #028b40;
+    background: #01b97c;
+    width: 200px;
+}
+
+#settings_header {
+    background: rgb(245, 105, 105);
+
+}
frontend/webui/generation_settings_ui.py
ADDED
@@ -0,0 +1,157 @@
+import gradio as gr
+from state import get_settings
+from backend.models.gen_images import ImageFormat
+
+app_settings = get_settings()
+
+
+def on_change_inference_steps(steps):
+    app_settings.settings.lcm_diffusion_setting.inference_steps = steps
+
+
+def on_change_image_width(img_width):
+    app_settings.settings.lcm_diffusion_setting.image_width = img_width
+
+
+def on_change_image_height(img_height):
+    app_settings.settings.lcm_diffusion_setting.image_height = img_height
+
+
+def on_change_num_images(num_images):
+    app_settings.settings.lcm_diffusion_setting.number_of_images = num_images
+
+
+def on_change_guidance_scale(guidance_scale):
+    app_settings.settings.lcm_diffusion_setting.guidance_scale = guidance_scale
+
+
+def on_change_seed_value(seed):
+    app_settings.settings.lcm_diffusion_setting.seed = seed
+
+
+def on_change_seed_checkbox(seed_checkbox):
+    app_settings.settings.lcm_diffusion_setting.use_seed = seed_checkbox
+
+
+def on_change_safety_checker_checkbox(safety_checker_checkbox):
+    app_settings.settings.lcm_diffusion_setting.use_safety_checker = (
+        safety_checker_checkbox
+    )
+
+
+def on_change_tiny_auto_encoder_checkbox(tiny_auto_encoder_checkbox):
+    app_settings.settings.lcm_diffusion_setting.use_tiny_auto_encoder = (
+        tiny_auto_encoder_checkbox
+    )
+
+
+def on_offline_checkbox(offline_checkbox):
+    app_settings.settings.lcm_diffusion_setting.use_offline_model = offline_checkbox
+
+
+def on_change_image_format(image_format):
+    if image_format == "PNG":
+        app_settings.settings.generated_images.format = ImageFormat.PNG.value.upper()
+    else:
+        app_settings.settings.generated_images.format = ImageFormat.JPEG.value.upper()
+
+    app_settings.save()
+
+
+def get_generation_settings_ui() -> None:
+    with gr.Blocks():
+        with gr.Row():
+            with gr.Column():
+                num_inference_steps = gr.Slider(
+                    1,
+                    25,
+                    value=app_settings.settings.lcm_diffusion_setting.inference_steps,
+                    step=1,
+                    label="Inference Steps",
+                    interactive=True,
+                )
+
+                image_height = gr.Slider(
+                    256,
+                    1024,
+                    value=app_settings.settings.lcm_diffusion_setting.image_height,
+                    step=256,
+                    label="Image Height",
+                    interactive=True,
+                )
+                image_width = gr.Slider(
+                    256,
+                    1024,
+                    value=app_settings.settings.lcm_diffusion_setting.image_width,
+                    step=256,
+                    label="Image Width",
+                    interactive=True,
+                )
+                num_images = gr.Slider(
+                    1,
+                    50,
+                    value=app_settings.settings.lcm_diffusion_setting.number_of_images,
+                    step=1,
+                    label="Number of images to generate",
+                    interactive=True,
+                )
+                guidance_scale = gr.Slider(
+                    1.0,
+                    10.0,
+                    value=app_settings.settings.lcm_diffusion_setting.guidance_scale,
+                    step=0.1,
+                    label="Guidance Scale",
+                    interactive=True,
+                )
+
+                seed = gr.Slider(
+                    value=app_settings.settings.lcm_diffusion_setting.seed,
+                    minimum=0,
+                    maximum=999999999,
+                    label="Seed",
+                    step=1,
+                    interactive=True,
+                )
+                seed_checkbox = gr.Checkbox(
+                    label="Use seed",
+                    value=app_settings.settings.lcm_diffusion_setting.use_seed,
+                    interactive=True,
+                )
+
+                safety_checker_checkbox = gr.Checkbox(
+                    label="Use Safety Checker",
+                    value=app_settings.settings.lcm_diffusion_setting.use_safety_checker,
+                    interactive=True,
+                )
+                tiny_auto_encoder_checkbox = gr.Checkbox(
+                    label="Use tiny auto encoder for SD",
+                    value=app_settings.settings.lcm_diffusion_setting.use_tiny_auto_encoder,
+                    interactive=True,
+                )
+                offline_checkbox = gr.Checkbox(
+                    label="Use locally cached model or downloaded model folder(offline)",
+                    value=app_settings.settings.lcm_diffusion_setting.use_offline_model,
+                    interactive=True,
+                )
+                img_format = gr.Radio(
+                    label="Output image format",
+                    choices=["PNG", "JPEG"],
+                    value=app_settings.settings.generated_images.format,
+                    interactive=True,
+                )
+
+        num_inference_steps.change(on_change_inference_steps, num_inference_steps)
+        image_height.change(on_change_image_height, image_height)
+        image_width.change(on_change_image_width, image_width)
+        num_images.change(on_change_num_images, num_images)
+        guidance_scale.change(on_change_guidance_scale, guidance_scale)
+        seed.change(on_change_seed_value, seed)
+        seed_checkbox.change(on_change_seed_checkbox, seed_checkbox)
+        safety_checker_checkbox.change(
+            on_change_safety_checker_checkbox, safety_checker_checkbox
+        )
+        tiny_auto_encoder_checkbox.change(
+            on_change_tiny_auto_encoder_checkbox, tiny_auto_encoder_checkbox
+        )
+        offline_checkbox.change(on_offline_checkbox, offline_checkbox)
+        img_format.change(on_change_image_format, img_format)
frontend/webui/image_to_image_ui.py
ADDED
@@ -0,0 +1,120 @@
+from typing import Any
+import gradio as gr
+from backend.models.lcmdiffusion_setting import DiffusionTask
+from models.interface_types import InterfaceType
+from frontend.utils import is_reshape_required
+from constants import DEVICE
+from state import get_settings, get_context
+from concurrent.futures import ThreadPoolExecutor
+
+
+app_settings = get_settings()
+
+previous_width = 0
+previous_height = 0
+previous_model_id = ""
+previous_num_of_images = 0
+
+
+def generate_image_to_image(
+    prompt,
+    negative_prompt,
+    init_image,
+    strength,
+) -> Any:
+    context = get_context(InterfaceType.WEBUI)
+    global previous_height, previous_width, previous_model_id, previous_num_of_images, app_settings
+
+    app_settings.settings.lcm_diffusion_setting.prompt = prompt
+    app_settings.settings.lcm_diffusion_setting.negative_prompt = negative_prompt
+    app_settings.settings.lcm_diffusion_setting.init_image = init_image
+    app_settings.settings.lcm_diffusion_setting.strength = strength
+
+    app_settings.settings.lcm_diffusion_setting.diffusion_task = (
+        DiffusionTask.image_to_image.value
+    )
+    model_id = app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id
+    reshape = False
+    image_width = app_settings.settings.lcm_diffusion_setting.image_width
+    image_height = app_settings.settings.lcm_diffusion_setting.image_height
+    num_images = app_settings.settings.lcm_diffusion_setting.number_of_images
+    if app_settings.settings.lcm_diffusion_setting.use_openvino:
+        reshape = is_reshape_required(
+            previous_width,
+            image_width,
+            previous_height,
+            image_height,
+            previous_model_id,
+            model_id,
+            previous_num_of_images,
+            num_images,
+        )
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(
+            context.generate_text_to_image,
+            app_settings.settings,
+            reshape,
+            DEVICE,
+        )
+        images = future.result()
+
+    previous_width = image_width
+    previous_height = image_height
+    previous_model_id = model_id
+    previous_num_of_images = num_images
+    return images
+
+
+def get_image_to_image_ui() -> None:
+    with gr.Blocks():
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Init image", type="pil")
+                with gr.Row():
+                    prompt = gr.Textbox(
+                        show_label=False,
+                        lines=3,
+                        placeholder="A fantasy landscape",
+                        container=False,
+                    )
+
+                    generate_btn = gr.Button(
+                        "Generate",
+                        elem_id="generate_button",
+                        scale=0,
+                    )
+                negative_prompt = gr.Textbox(
+                    label="Negative prompt (Works in LCM-LoRA mode, set guidance > 1.0):",
+                    lines=1,
+                    placeholder="",
+                )
+                strength = gr.Slider(
+                    0.1,
+                    1,
+                    value=app_settings.settings.lcm_diffusion_setting.strength,
+                    step=0.01,
+                    label="Strength",
+                )
+
+                input_params = [
+                    prompt,
+                    negative_prompt,
+                    input_image,
+                    strength,
+                ]
+
+            with gr.Column():
+                output = gr.Gallery(
+                    label="Generated images",
+                    show_label=True,
+                    elem_id="gallery",
+                    columns=2,
+                    height=512,
+                )
+
+        generate_btn.click(
+            fn=generate_image_to_image,
+            inputs=input_params,
+            outputs=output,
+        )
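Both this tab and the image-variations tab that follows run generation through a one-worker `ThreadPoolExecutor` and block on `future.result()`, so the Gradio handler only returns once the images exist. The submit/result pattern in isolation, with `slow_generate` as a hypothetical stand-in for `context.generate_text_to_image`:

from concurrent.futures import ThreadPoolExecutor

def slow_generate(count):
    return [f"image-{i}" for i in range(count)]  # stand-in for the real generation call

with ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(slow_generate, 2)
    images = future.result()  # blocks until the worker returns

print(images)  # ['image-0', 'image-1']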
frontend/webui/image_variations_ui.py
ADDED
@@ -0,0 +1,106 @@
+from typing import Any
+import gradio as gr
+from backend.models.lcmdiffusion_setting import DiffusionTask
+from context import Context
+from models.interface_types import InterfaceType
+from frontend.utils import is_reshape_required
+from constants import DEVICE
+from state import get_settings, get_context
+from concurrent.futures import ThreadPoolExecutor
+
+app_settings = get_settings()
+
+
+previous_width = 0
+previous_height = 0
+previous_model_id = ""
+previous_num_of_images = 0
+
+
+def generate_image_variations(
+    init_image,
+    variation_strength,
+) -> Any:
+    context = get_context(InterfaceType.WEBUI)
+    global previous_height, previous_width, previous_model_id, previous_num_of_images, app_settings
+
+    app_settings.settings.lcm_diffusion_setting.init_image = init_image
+    app_settings.settings.lcm_diffusion_setting.strength = variation_strength
+    app_settings.settings.lcm_diffusion_setting.prompt = ""
+    app_settings.settings.lcm_diffusion_setting.negative_prompt = ""
+
+    app_settings.settings.lcm_diffusion_setting.diffusion_task = (
+        DiffusionTask.image_to_image.value
+    )
+    model_id = app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id
+    reshape = False
+    image_width = app_settings.settings.lcm_diffusion_setting.image_width
+    image_height = app_settings.settings.lcm_diffusion_setting.image_height
+    num_images = app_settings.settings.lcm_diffusion_setting.number_of_images
+    if app_settings.settings.lcm_diffusion_setting.use_openvino:
+        reshape = is_reshape_required(
+            previous_width,
+            image_width,
+            previous_height,
+            image_height,
+            previous_model_id,
+            model_id,
+            previous_num_of_images,
+            num_images,
+        )
+
+    with ThreadPoolExecutor(max_workers=1) as executor:
+        future = executor.submit(
+            context.generate_text_to_image,
+            app_settings.settings,
+            reshape,
+            DEVICE,
+        )
+        images = future.result()
+
+    previous_width = image_width
+    previous_height = image_height
+    previous_model_id = model_id
+    previous_num_of_images = num_images
+    return images
+
+
+def get_image_variations_ui() -> None:
+    with gr.Blocks():
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Init image", type="pil")
+                with gr.Row():
+                    generate_btn = gr.Button(
+                        "Generate",
+                        elem_id="generate_button",
+                        scale=0,
+                    )
+
+                variation_strength = gr.Slider(
+                    0.1,
+                    1,
+                    value=0.4,
+                    step=0.01,
+                    label="Variations Strength",
+                )
+
+                input_params = [
+                    input_image,
+                    variation_strength,
+                ]
+
+            with gr.Column():
+                output = gr.Gallery(
+                    label="Generated images",
+                    show_label=True,
+                    elem_id="gallery",
+                    columns=2,
+                    height=512,
+                )
+
+        generate_btn.click(
+            fn=generate_image_variations,
+            inputs=input_params,
+            outputs=output,
+        )
frontend/webui/lora_models_ui.py
ADDED
@@ -0,0 +1,185 @@
+import gradio as gr
+from os import path
+from backend.lora import (
+    get_lora_models,
+    get_active_lora_weights,
+    update_lora_weights,
+    load_lora_weight,
+)
+from state import get_settings, get_context
+from frontend.utils import get_valid_lora_model
+from models.interface_types import InterfaceType
+from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
+
+
+_MAX_LORA_WEIGHTS = 5
+
+_custom_lora_sliders = []
+_custom_lora_names = []
+_custom_lora_columns = []
+
+app_settings = get_settings()
+
+
+def on_click_update_weight(*lora_weights):
+    update_weights = []
+    active_weights = get_active_lora_weights()
+    if not len(active_weights):
+        gr.Warning("No active LoRAs, first you need to load LoRA model")
+        return
+    for idx, lora in enumerate(active_weights):
+        update_weights.append(
+            (
+                lora[0],
+                lora_weights[idx],
+            )
+        )
+    if len(update_weights) > 0:
+        update_lora_weights(
+            get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
+            app_settings.settings.lcm_diffusion_setting,
+            update_weights,
+        )
+
+
+def on_click_load_lora(lora_name, lora_weight):
+    if app_settings.settings.lcm_diffusion_setting.use_openvino:
+        gr.Warning("Currently LoRA is not supported in OpenVINO.")
+        return
+    lora_models_map = get_lora_models(
+        app_settings.settings.lcm_diffusion_setting.lora.models_dir
+    )
+
+    # Load a new LoRA
+    settings = app_settings.settings.lcm_diffusion_setting
+    settings.lora.fuse = False
+    settings.lora.enabled = False
+    settings.lora.path = lora_models_map[lora_name]
+    settings.lora.weight = lora_weight
+    if not path.exists(settings.lora.path):
+        gr.Warning("Invalid LoRA model path!")
+        return
+    pipeline = get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline
+    if not pipeline:
+        gr.Warning("Pipeline not initialized. Please generate an image first.")
+        return
+    settings.lora.enabled = True
+    load_lora_weight(
+        get_context(InterfaceType.WEBUI).lcm_text_to_image.pipeline,
+        settings,
+    )
+
+    # Update Gradio LoRA UI
+    global _MAX_LORA_WEIGHTS
+    values = []
+    labels = []
+    rows = []
+    active_weights = get_active_lora_weights()
+    for idx, lora in enumerate(active_weights):
+        labels.append(f"{lora[0]}: ")
+        values.append(lora[1])
+        rows.append(gr.Row.update(visible=True))
+    for i in range(len(active_weights), _MAX_LORA_WEIGHTS):
+        labels.append(f"Update weight")
+        values.append(0.0)
+        rows.append(gr.Row.update(visible=False))
+    return labels + values + rows
+
+
+def get_lora_models_ui() -> None:
+    with gr.Blocks() as ui:
+        gr.HTML(
+            "Download and place your LoRA model weights in <b>lora_models</b> folders and restart App"
+        )
+        with gr.Row():
+
+            with gr.Column():
+                with gr.Row():
+                    lora_models_map = get_lora_models(
+                        app_settings.settings.lcm_diffusion_setting.lora.models_dir
+                    )
+                    valid_model = get_valid_lora_model(
+                        list(lora_models_map.values()),
+                        app_settings.settings.lcm_diffusion_setting.lora.path,
+                        app_settings.settings.lcm_diffusion_setting.lora.models_dir,
+                    )
+                    if valid_model != "":
+                        valid_model_path = lora_models_map[valid_model]
+                        app_settings.settings.lcm_diffusion_setting.lora.path = (
+                            valid_model_path
+                        )
+                    else:
+                        app_settings.settings.lcm_diffusion_setting.lora.path = ""
+
+                    lora_model = gr.Dropdown(
+                        lora_models_map.keys(),
+                        label="LoRA model",
+                        info="LoRA model weight to load (You can use Lora models from Civitai or Hugging Face .safetensors format)",
+                        value=valid_model,
+                        interactive=True,
+                    )
+
+                    lora_weight = gr.Slider(
+                        0.0,
+                        1.0,
+                        value=app_settings.settings.lcm_diffusion_setting.lora.weight,
+                        step=0.05,
+                        label="Initial Lora weight",
+                        interactive=True,
+                    )
+                    load_lora_btn = gr.Button(
+                        "Load selected LoRA",
+                        elem_id="load_lora_button",
+                        scale=0,
+                    )
+
+                with gr.Row():
+                    gr.Markdown(
+                        "## Loaded LoRA models",
+                        show_label=False,
+                    )
+                    update_lora_weights_btn = gr.Button(
+                        "Update LoRA weights",
+                        elem_id="load_lora_button",
+                        scale=0,
+                    )
+
+                global _MAX_LORA_WEIGHTS
+                global _custom_lora_sliders
+                global _custom_lora_names
+                global _custom_lora_columns
+                for i in range(0, _MAX_LORA_WEIGHTS):
+                    new_row = gr.Column(visible=False)
+                    _custom_lora_columns.append(new_row)
+                    with new_row:
+                        lora_name = gr.Markdown(
+                            "Lora Name",
+                            show_label=True,
+                        )
+                        lora_slider = gr.Slider(
+                            0.0,
+                            1.0,
+                            step=0.05,
+                            label="LoRA weight",
+                            interactive=True,
+                            visible=True,
+                        )
+
+                        _custom_lora_names.append(lora_name)
+                        _custom_lora_sliders.append(lora_slider)
+
+        load_lora_btn.click(
+            fn=on_click_load_lora,
+            inputs=[lora_model, lora_weight],
+            outputs=[
+                *_custom_lora_names,
+                *_custom_lora_sliders,
+                *_custom_lora_columns,
+            ],
+        )
+
+        update_lora_weights_btn.click(
+            fn=on_click_update_weight,
+            inputs=[*_custom_lora_sliders],
+            outputs=None,
+        )
frontend/webui/models_ui.py
ADDED
@@ -0,0 +1,85 @@
+from app_settings import AppSettings
+from typing import Any
+import gradio as gr
+from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
+from state import get_settings
+from frontend.utils import get_valid_model_id
+
+app_settings = get_settings()
+app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id = get_valid_model_id(
+    app_settings.openvino_lcm_models,
+    app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id,
+)
+
+
+def change_lcm_model_id(model_id):
+    app_settings.settings.lcm_diffusion_setting.lcm_model_id = model_id
+
+
+def change_lcm_lora_model_id(model_id):
+    app_settings.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id
+
+
+def change_lcm_lora_base_model_id(model_id):
+    app_settings.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id
+
+
+def change_openvino_lcm_model_id(model_id):
+    app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
+
+
+def get_models_ui() -> None:
+    with gr.Blocks():
+        with gr.Row():
+            lcm_model_id = gr.Dropdown(
+                app_settings.lcm_models,
+                label="LCM model",
+                info="Diffusers LCM model ID",
+                value=get_valid_model_id(
+                    app_settings.lcm_models,
+                    app_settings.settings.lcm_diffusion_setting.lcm_model_id,
+                    LCM_DEFAULT_MODEL,
+                ),
+                interactive=True,
+            )
+        with gr.Row():
+            lcm_lora_model_id = gr.Dropdown(
+                app_settings.lcm_lora_models,
+                label="LCM LoRA model",
+                info="Diffusers LCM LoRA model ID",
+                value=get_valid_model_id(
+                    app_settings.lcm_lora_models,
+                    app_settings.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
+                ),
+                interactive=True,
+            )
+            lcm_lora_base_model_id = gr.Dropdown(
+                app_settings.stable_diffsuion_models,
+                label="LCM LoRA base model",
+                info="Diffusers LCM LoRA base model ID",
+                value=get_valid_model_id(
+                    app_settings.stable_diffsuion_models,
+                    app_settings.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
+                ),
+                interactive=True,
+            )
+        with gr.Row():
+            lcm_openvino_model_id = gr.Dropdown(
+                app_settings.openvino_lcm_models,
+                label="LCM OpenVINO model",
+                info="OpenVINO LCM-LoRA fused model ID",
+                value=get_valid_model_id(
+                    app_settings.openvino_lcm_models,
+                    app_settings.settings.lcm_diffusion_setting.openvino_lcm_model_id,
+                ),
+                interactive=True,
+            )
+
+        lcm_model_id.change(change_lcm_model_id, lcm_model_id)
+        lcm_lora_model_id.change(change_lcm_lora_model_id, lcm_lora_model_id)
+        lcm_lora_base_model_id.change(
+            change_lcm_lora_base_model_id, lcm_lora_base_model_id
+        )
+        lcm_openvino_model_id.change(
+            change_openvino_lcm_model_id, lcm_openvino_model_id
+        )