AlekseyCalvin commited on
Commit
9b31b59
1 Parent(s): a2aad52

Upload 2 files

Browse files
Files changed (2) hide show
  1. mod.py +360 -0
  2. modutils.py +1290 -0
mod.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ import gc
7
+ import subprocess
8
+ from env import num_cns, model_trigger
9
+
10
+
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
+ subprocess.run('pip cache purge', shell=True)
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ torch.set_grad_enabled(False)
15
+
16
+
17
+ control_images = [None] * num_cns
18
+ control_modes = [-1] * num_cns
19
+ control_scales = [0] * num_cns
20
+
21
+
22
+ def is_repo_name(s):
23
+ import re
24
+ return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
25
+
26
+
27
+ def is_repo_exists(repo_id):
28
+ from huggingface_hub import HfApi
29
+ api = HfApi()
30
+ try:
31
+ if api.repo_exists(repo_id=repo_id): return True
32
+ else: return False
33
+ except Exception as e:
34
+ print(f"Error: Failed to connect {repo_id}.")
35
+ print(e)
36
+ return True # for safe
37
+
38
+
39
+ from translatepy import Translator
40
+ translator = Translator()
41
+ def translate_to_en(input: str):
42
+ try:
43
+ output = str(translator.translate(input, 'English'))
44
+ except Exception as e:
45
+ output = input
46
+ print(e)
47
+ return output
48
+
49
+
50
+ def clear_cache():
51
+ try:
52
+ torch.cuda.empty_cache()
53
+ #torch.cuda.reset_max_memory_allocated()
54
+ #torch.cuda.reset_peak_memory_stats()
55
+ gc.collect()
56
+ except Exception as e:
57
+ print(e)
58
+ raise Exception(f"Cache clearing error: {e}") from e
59
+
60
+
61
+ def get_repo_safetensors(repo_id: str):
62
+ from huggingface_hub import HfApi
63
+ api = HfApi()
64
+ try:
65
+ if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
66
+ files = api.list_repo_files(repo_id=repo_id)
67
+ except Exception as e:
68
+ print(f"Error: Failed to get {repo_id}'s info.")
69
+ print(e)
70
+ gr.Warning(f"Error: Failed to get {repo_id}'s info.")
71
+ return gr.update(choices=[])
72
+ files = [f for f in files if f.endswith(".safetensors")]
73
+ if len(files) == 0: return gr.update(value="", choices=[])
74
+ else: return gr.update(value=files[0], choices=files)
75
+
76
+
77
+ def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
78
+ width, height = pil_img.size
79
+ if width == height:
80
+ return pil_img
81
+ elif width > height:
82
+ result = Image.new(pil_img.mode, (width, width), background_color)
83
+ result.paste(pil_img, (0, (width - height) // 2))
84
+ return result
85
+ else:
86
+ result = Image.new(pil_img.mode, (height, height), background_color)
87
+ result.paste(pil_img, ((height - width) // 2, 0))
88
+ return result
89
+
90
+
91
+ # https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
92
+ def resize_image(image, target_width, target_height, crop=True):
93
+ from image_datasets.canny_dataset import c_crop
94
+ if crop:
95
+ image = c_crop(image) # Crop the image to square
96
+ original_width, original_height = image.size
97
+
98
+ # Resize to match the target size without stretching
99
+ scale = max(target_width / original_width, target_height / original_height)
100
+ resized_width = int(scale * original_width)
101
+ resized_height = int(scale * original_height)
102
+
103
+ image = image.resize((resized_width, resized_height), Image.LANCZOS)
104
+
105
+ # Center crop to match the target dimensions
106
+ left = (resized_width - target_width) // 2
107
+ top = (resized_height - target_height) // 2
108
+ image = image.crop((left, top, left + target_width, top + target_height))
109
+ else:
110
+ image = image.resize((target_width, target_height), Image.LANCZOS)
111
+
112
+ return image
113
+
114
+
115
+ # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
116
+ # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
117
+ controlnet_union_modes = {
118
+ "None": -1,
119
+ #"scribble_hed": 0,
120
+ "canny": 0, # supported
121
+ "mlsd": 0, #supported
122
+ "tile": 1, #supported
123
+ "depth_midas": 2, # supported
124
+ "blur": 3, # supported
125
+ "openpose": 4, # supported
126
+ "gray": 5, # supported
127
+ "low_quality": 6, # supported
128
+ }
129
+
130
+
131
+ # https://github.com/pytorch/pytorch/issues/123834
132
+ def get_control_params():
133
+ from diffusers.utils import load_image
134
+ modes = []
135
+ images = []
136
+ scales = []
137
+ for i, mode in enumerate(control_modes):
138
+ if mode == -1 or control_images[i] is None: continue
139
+ modes.append(control_modes[i])
140
+ images.append(load_image(control_images[i]))
141
+ scales.append(control_scales[i])
142
+ return modes, images, scales
143
+
144
+
145
+ from preprocessor import Preprocessor
146
+ def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
147
+ preprocess_resolution: int):
148
+ if control_mode == "None": return image
149
+ image_resolution = max(width, height)
150
+ image_before = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
151
+ # generated control_
152
+ print("start to generate control image")
153
+ preprocessor = Preprocessor()
154
+ if control_mode == "depth_midas":
155
+ preprocessor.load("Midas")
156
+ control_image = preprocessor(
157
+ image=image_before,
158
+ image_resolution=image_resolution,
159
+ detect_resolution=preprocess_resolution,
160
+ )
161
+ if control_mode == "openpose":
162
+ preprocessor.load("Openpose")
163
+ control_image = preprocessor(
164
+ image=image_before,
165
+ hand_and_face=True,
166
+ image_resolution=image_resolution,
167
+ detect_resolution=preprocess_resolution,
168
+ )
169
+ if control_mode == "canny":
170
+ preprocessor.load("Canny")
171
+ control_image = preprocessor(
172
+ image=image_before,
173
+ image_resolution=image_resolution,
174
+ detect_resolution=preprocess_resolution,
175
+ )
176
+
177
+ if control_mode == "mlsd":
178
+ preprocessor.load("MLSD")
179
+ control_image = preprocessor(
180
+ image=image_before,
181
+ image_resolution=image_resolution,
182
+ detect_resolution=preprocess_resolution,
183
+ )
184
+
185
+ if control_mode == "scribble_hed":
186
+ preprocessor.load("HED")
187
+ control_image = preprocessor(
188
+ image=image_before,
189
+ image_resolution=image_resolution,
190
+ detect_resolution=preprocess_resolution,
191
+ )
192
+
193
+ if control_mode == "low_quality" or control_mode == "gray" or control_mode == "blur" or control_mode == "tile":
194
+ control_image = image_before
195
+ image_width = 768
196
+ image_height = 768
197
+ else:
198
+ # make sure control image size is same as resized_image
199
+ image_width, image_height = control_image.size
200
+
201
+ image_after = resize_image(control_image, width, height, False)
202
+ ref_width, ref_height = image.size
203
+ print(f"generate control image success: {ref_width}x{ref_height} => {image_width}x{image_height}")
204
+ return image_after
205
+
206
+
207
+ def get_control_union_mode():
208
+ return list(controlnet_union_modes.keys())
209
+
210
+
211
+ def set_control_union_mode(i: int, mode: str, scale: str):
212
+ global control_modes
213
+ global control_scales
214
+ control_modes[i] = controlnet_union_modes.get(mode, 0)
215
+ control_scales[i] = scale
216
+ if mode != "None": return True
217
+ else: return gr.update(visible=True)
218
+
219
+
220
+ def set_control_union_image(i: int, mode: str, image: Image.Image | None, height: int, width: int, preprocess_resolution: int):
221
+ global control_images
222
+ if image is None: return None
223
+ control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
224
+ return control_images[i]
225
+
226
+
227
+ def preprocess_i2i_image(image_path: str, is_preprocess: bool, height: int, width: int):
228
+ try:
229
+ if not is_preprocess: return image_path
230
+ image_resolution = max(width, height)
231
+ image = Image.open(image_path)
232
+ image_resized = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
233
+ image_resized.save(image_path)
234
+ except Exception as e:
235
+ raise gr.Error(f"Error: {e}")
236
+ return image_path
237
+
238
+
239
+ def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
240
+ lorajson[i]["name"] = str(name) if name != "None" else ""
241
+ lorajson[i]["scale"] = float(scale)
242
+ lorajson[i]["filename"] = str(filename)
243
+ lorajson[i]["trigger"] = str(trigger)
244
+ return lorajson
245
+
246
+
247
+ def is_valid_lora(lorajson: list[dict]):
248
+ valid = False
249
+ for d in lorajson:
250
+ if "name" in d.keys() and d["name"] and d["name"] != "None": valid = True
251
+ return valid
252
+
253
+
254
+ def get_trigger_word(lorajson: list[dict]):
255
+ trigger = ""
256
+ for d in lorajson:
257
+ if "name" in d.keys() and d["name"] and d["name"] != "None" and d["trigger"]:
258
+ trigger += ", " + d["trigger"]
259
+ return trigger
260
+
261
+
262
+ def get_model_trigger(model_name: str):
263
+ trigger = ""
264
+ if model_name in model_trigger.keys(): trigger += ", " + model_trigger[model_name]
265
+ return trigger
266
+
267
+
268
+ # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
269
+ # https://github.com/huggingface/diffusers/issues/4919
270
+ def fuse_loras(pipe, lorajson: list[dict]):
271
+ try:
272
+ if not lorajson or not isinstance(lorajson, list): return pipe, [], []
273
+ a_list = []
274
+ w_list = []
275
+ for d in lorajson:
276
+ if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None": continue
277
+ k = d["name"]
278
+ if is_repo_name(k) and is_repo_exists(k):
279
+ a_name = Path(k).stem
280
+ pipe.load_lora_weights(k, weight_name=d["filename"], adapter_name = a_name, low_cpu_mem_usage=True)
281
+ elif not Path(k).exists():
282
+ print(f"LoRA not found: {k}")
283
+ continue
284
+ else:
285
+ w_name = Path(k).name
286
+ a_name = Path(k).stem
287
+ pipe.load_lora_weights(k, weight_name = w_name, adapter_name = a_name, low_cpu_mem_usage=True)
288
+ a_list.append(a_name)
289
+ w_list.append(d["scale"])
290
+ if not a_list: return pipe, [], []
291
+ #pipe.set_adapters(a_list, adapter_weights=w_list)
292
+ #pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
293
+ #pipe.unload_lora_weights()
294
+ return pipe, a_list, w_list
295
+ except Exception as e:
296
+ print(f"External LoRA Error: {e}")
297
+ raise Exception(f"External LoRA Error: {e}") from e
298
+
299
+
300
+ def description_ui():
301
+ gr.Markdown(
302
+ """
303
+ - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
304
+ [multimodalart/flux-lora-lab](https://huggingface.co/spaces/multimodalart/flux-lora-lab),
305
+ [jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union),
306
+ [DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),
307
+ [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
308
+ """
309
+ )
310
+
311
+
312
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
313
+ def load_prompt_enhancer():
314
+ try:
315
+ model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
316
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
317
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
318
+ enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
319
+ except Exception as e:
320
+ print(e)
321
+ enhancer_flux = None
322
+ return enhancer_flux
323
+
324
+
325
+ enhancer_flux = load_prompt_enhancer()
326
+
327
+
328
+ @spaces.GPU(duration=30)
329
+ def enhance_prompt(input_prompt):
330
+ result = enhancer_flux("enhance prompt: " + translate_to_en(input_prompt), max_length = 256)
331
+ enhanced_text = result[0]['generated_text']
332
+ return enhanced_text
333
+
334
+
335
+ def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
336
+ import uuid
337
+ from PIL import PngImagePlugin
338
+ import json
339
+ try:
340
+ if savefile is None: savefile = f"{modelname.split('/')[-1]}_{str(uuid.uuid4())}.png"
341
+ metadata = {"prompt": prompt, "Model": {"Model": modelname.split("/")[-1]}}
342
+ metadata["num_inference_steps"] = steps
343
+ metadata["guidance_scale"] = cfg
344
+ metadata["seed"] = seed
345
+ metadata["resolution"] = f"{width} x {height}"
346
+ metadata_str = json.dumps(metadata)
347
+ info = PngImagePlugin.PngInfo()
348
+ info.add_text("metadata", metadata_str)
349
+ image.save(savefile, "PNG", pnginfo=info)
350
+ return str(Path(savefile).resolve())
351
+ except Exception as e:
352
+ print(f"Failed to save image file: {e}")
353
+ raise Exception(f"Failed to save image file:") from e
354
+
355
+
356
+ load_prompt_enhancer.zerogpu = True
357
+ fuse_loras.zerogpu = True
358
+ preprocess_image.zerogpu = True
359
+ get_control_params.zerogpu = True
360
+ clear_cache.zerogpu = True
modutils.py ADDED
@@ -0,0 +1,1290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import json
3
+ import gradio as gr
4
+ from huggingface_hub import HfApi
5
+ import os
6
+ from pathlib import Path
7
+ from PIL import Image
8
+
9
+
10
+ from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
11
+ HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
12
+ directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
13
+
14
+
15
+ MODEL_TYPE_DICT = {
16
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
17
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
18
+ "diffusers:FluxPipeline": "FLUX",
19
+ }
20
+
21
+
22
+ def get_user_agent():
23
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
24
+
25
+
26
+ def to_list(s):
27
+ return [x.strip() for x in s.split(",") if not s == ""]
28
+
29
+
30
+ def list_uniq(l):
31
+ return sorted(set(l), key=l.index)
32
+
33
+
34
+ def list_sub(a, b):
35
+ return [e for e in a if e not in b]
36
+
37
+
38
+ def is_repo_name(s):
39
+ import re
40
+ return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
41
+
42
+
43
+ from translatepy import Translator
44
+ translator = Translator()
45
+ def translate_to_en(input: str):
46
+ try:
47
+ output = str(translator.translate(input, 'English'))
48
+ except Exception as e:
49
+ output = input
50
+ print(e)
51
+ return output
52
+
53
+
54
+ def get_local_model_list(dir_path):
55
+ model_list = []
56
+ valid_extensions = ('.ckpt', '.pt', '.pth', '.safetensors', '.bin')
57
+ for file in Path(dir_path).glob("*"):
58
+ if file.suffix in valid_extensions:
59
+ file_path = str(Path(f"{dir_path}/{file.name}"))
60
+ model_list.append(file_path)
61
+ return model_list
62
+
63
+
64
+ def download_things(directory, url, hf_token="", civitai_api_key=""):
65
+ url = url.strip()
66
+ if "drive.google.com" in url:
67
+ original_dir = os.getcwd()
68
+ os.chdir(directory)
69
+ os.system(f"gdown --fuzzy {url}")
70
+ os.chdir(original_dir)
71
+ elif "huggingface.co" in url:
72
+ url = url.replace("?download=true", "")
73
+ # url = urllib.parse.quote(url, safe=':/') # fix encoding
74
+ if "/blob/" in url:
75
+ url = url.replace("/blob/", "/resolve/")
76
+ user_header = f'"Authorization: Bearer {hf_token}"'
77
+ if hf_token:
78
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
79
+ else:
80
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
81
+ elif "civitai.com" in url:
82
+ if "?" in url:
83
+ url = url.split("?")[0]
84
+ if civitai_api_key:
85
+ url = url + f"?token={civitai_api_key}"
86
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
87
+ else:
88
+ print("\033[91mYou need an API key to download Civitai models.\033[0m")
89
+ else:
90
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
91
+
92
+
93
+ def escape_lora_basename(basename: str):
94
+ return basename.replace(".", "_").replace(" ", "_").replace(",", "")
95
+
96
+
97
+ def to_lora_key(path: str):
98
+ return escape_lora_basename(Path(path).stem)
99
+
100
+
101
+ def to_lora_path(key: str):
102
+ if Path(key).is_file(): return key
103
+ path = Path(f"{directory_loras}/{escape_lora_basename(key)}.safetensors")
104
+ return str(path)
105
+
106
+
107
+ def safe_float(input):
108
+ output = 1.0
109
+ try:
110
+ output = float(input)
111
+ except Exception:
112
+ output = 1.0
113
+ return output
114
+
115
+
116
+ def save_images(images: list[Image.Image], metadatas: list[str]):
117
+ from PIL import PngImagePlugin
118
+ import uuid
119
+ try:
120
+ output_images = []
121
+ for image, metadata in zip(images, metadatas):
122
+ info = PngImagePlugin.PngInfo()
123
+ info.add_text("parameters", metadata)
124
+ savefile = f"{str(uuid.uuid4())}.png"
125
+ image.save(savefile, "PNG", pnginfo=info)
126
+ output_images.append(str(Path(savefile).resolve()))
127
+ return output_images
128
+ except Exception as e:
129
+ print(f"Failed to save image file: {e}")
130
+ raise Exception(f"Failed to save image file:") from e
131
+
132
+
133
+ def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
134
+ from datetime import datetime, timezone, timedelta
135
+ progress(0, desc="Updating gallery...")
136
+ dt_now = datetime.now(timezone(timedelta(hours=9)))
137
+ basename = dt_now.strftime('%Y%m%d_%H%M%S_')
138
+ i = 1
139
+ if not images: return images, gr.update(visible=False)
140
+ output_images = []
141
+ output_paths = []
142
+ for image in images:
143
+ filename = basename + str(i) + ".png"
144
+ i += 1
145
+ oldpath = Path(image[0])
146
+ newpath = oldpath
147
+ try:
148
+ if oldpath.exists():
149
+ newpath = oldpath.resolve().rename(Path(filename).resolve())
150
+ except Exception as e:
151
+ print(e)
152
+ finally:
153
+ output_paths.append(str(newpath))
154
+ output_images.append((str(newpath), str(filename)))
155
+ progress(1, desc="Gallery updated.")
156
+ return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
157
+
158
+
159
+ def download_private_repo(repo_id, dir_path, is_replace):
160
+ from huggingface_hub import snapshot_download
161
+ if not hf_read_token: return
162
+ try:
163
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], use_auth_token=hf_read_token)
164
+ except Exception as e:
165
+ print(f"Error: Failed to download {repo_id}.")
166
+ print(e)
167
+ return
168
+ if is_replace:
169
+ for file in Path(dir_path).glob("*"):
170
+ if file.exists() and "." in file.stem or " " in file.stem and file.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
171
+ newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}')
172
+ file.resolve().rename(newpath.resolve())
173
+
174
+
175
+ private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", ...}
176
+
177
+
178
+ def get_private_model_list(repo_id, dir_path):
179
+ global private_model_path_repo_dict
180
+ api = HfApi()
181
+ if not hf_read_token: return []
182
+ try:
183
+ files = api.list_repo_files(repo_id, token=hf_read_token)
184
+ except Exception as e:
185
+ print(f"Error: Failed to list {repo_id}.")
186
+ print(e)
187
+ return []
188
+ model_list = []
189
+ for file in files:
190
+ path = Path(f"{dir_path}/{file}")
191
+ if path.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
192
+ model_list.append(str(path))
193
+ for model in model_list:
194
+ private_model_path_repo_dict[model] = repo_id
195
+ return model_list
196
+
197
+
198
+ def download_private_file(repo_id, path, is_replace):
199
+ from huggingface_hub import hf_hub_download
200
+ file = Path(path)
201
+ newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
202
+ if not hf_read_token or newpath.exists(): return
203
+ filename = file.name
204
+ dirname = file.parent.name
205
+ try:
206
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, use_auth_token=hf_read_token)
207
+ except Exception as e:
208
+ print(f"Error: Failed to download {filename}.")
209
+ print(e)
210
+ return
211
+ if is_replace:
212
+ file.resolve().rename(newpath.resolve())
213
+
214
+
215
+ def download_private_file_from_somewhere(path, is_replace):
216
+ if not path in private_model_path_repo_dict.keys(): return
217
+ repo_id = private_model_path_repo_dict.get(path, None)
218
+ download_private_file(repo_id, path, is_replace)
219
+
220
+
221
+ model_id_list = []
222
+ def get_model_id_list():
223
+ global model_id_list
224
+ if len(model_id_list) != 0: return model_id_list
225
+ api = HfApi()
226
+ model_ids = []
227
+ try:
228
+ models_likes = []
229
+ for author in HF_MODEL_USER_LIKES:
230
+ models_likes.extend(api.list_models(author=author, task="text-to-image", cardData=True, sort="likes"))
231
+ models_ex = []
232
+ for author in HF_MODEL_USER_EX:
233
+ models_ex = api.list_models(author=author, task="text-to-image", cardData=True, sort="last_modified")
234
+ except Exception as e:
235
+ print(f"Error: Failed to list {author}'s models.")
236
+ print(e)
237
+ return model_ids
238
+ for model in models_likes:
239
+ model_ids.append(model.id) if not model.private else ""
240
+ anime_models = []
241
+ real_models = []
242
+ anime_models_flux = []
243
+ real_models_flux = []
244
+ for model in models_ex:
245
+ if not model.private and not model.gated:
246
+ if "diffusers:FluxPipeline" in model.tags: anime_models_flux.append(model.id) if "anime" in model.tags else real_models_flux.append(model.id)
247
+ else: anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
248
+ model_ids.extend(anime_models)
249
+ model_ids.extend(real_models)
250
+ model_ids.extend(anime_models_flux)
251
+ model_ids.extend(real_models_flux)
252
+ model_id_list = model_ids.copy()
253
+ return model_ids
254
+
255
+
256
+ model_id_list = get_model_id_list()
257
+
258
+
259
+ def get_t2i_model_info(repo_id: str):
260
+ api = HfApi(token=HF_TOKEN)
261
+ try:
262
+ if not is_repo_name(repo_id): return ""
263
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
264
+ except Exception as e:
265
+ print(f"Error: Failed to get {repo_id}'s info.")
266
+ print(e)
267
+ return ""
268
+ if model.private or model.gated: return ""
269
+ tags = model.tags
270
+ info = []
271
+ url = f"https://huggingface.co/{repo_id}/"
272
+ if not 'diffusers' in tags: return ""
273
+ for k, v in MODEL_TYPE_DICT.items():
274
+ if k in tags: info.append(v)
275
+ if model.card_data and model.card_data.tags:
276
+ info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
277
+ info.append(f"DLs: {model.downloads}")
278
+ info.append(f"likes: {model.likes}")
279
+ info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
280
+ md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
281
+ return gr.update(value=md)
282
+
283
+
284
+ def get_tupled_model_list(model_list):
285
+ if not model_list: return []
286
+ tupled_list = []
287
+ for repo_id in model_list:
288
+ api = HfApi()
289
+ try:
290
+ if not api.repo_exists(repo_id): continue
291
+ model = api.model_info(repo_id=repo_id)
292
+ except Exception as e:
293
+ print(e)
294
+ continue
295
+ if model.private or model.gated: continue
296
+ tags = model.tags
297
+ info = []
298
+ if not 'diffusers' in tags: continue
299
+ for k, v in MODEL_TYPE_DICT.items():
300
+ if k in tags: info.append(v)
301
+ if model.card_data and model.card_data.tags:
302
+ info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
303
+ if "pony" in info:
304
+ info.remove("pony")
305
+ name = f"{repo_id} (Pony🐴, {', '.join(info)})"
306
+ else:
307
+ name = f"{repo_id} ({', '.join(info)})"
308
+ tupled_list.append((name, repo_id))
309
+ return tupled_list
310
+
311
+
312
+ private_lora_dict = {}
313
+ try:
314
+ with open('lora_dict.json', encoding='utf-8') as f:
315
+ d = json.load(f)
316
+ for k, v in d.items():
317
+ private_lora_dict[escape_lora_basename(k)] = v
318
+ except Exception as e:
319
+ print(e)
320
+ loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict.copy()
321
+ civitai_not_exists_list = []
322
+ loras_url_to_path_dict = {} # {"URL to download": "local filepath", ...}
323
+ civitai_lora_last_results = {} # {"URL to download": {search results}, ...}
324
+ all_lora_list = []
325
+
326
+
327
+ private_lora_model_list = []
328
+ def get_private_lora_model_lists():
329
+ global private_lora_model_list
330
+ if len(private_lora_model_list) != 0: return private_lora_model_list
331
+ models1 = []
332
+ models2 = []
333
+ for repo in HF_LORA_PRIVATE_REPOS1:
334
+ models1.extend(get_private_model_list(repo, directory_loras))
335
+ for repo in HF_LORA_PRIVATE_REPOS2:
336
+ models2.extend(get_private_model_list(repo, directory_loras))
337
+ models = list_uniq(models1 + sorted(models2))
338
+ private_lora_model_list = models.copy()
339
+ return models
340
+
341
+
342
+ private_lora_model_list = get_private_lora_model_lists()
343
+
344
+
345
+ def get_civitai_info(path):
346
+ global civitai_not_exists_list
347
+ import requests
348
+ from urllib3.util import Retry
349
+ from requests.adapters import HTTPAdapter
350
+ if path in set(civitai_not_exists_list): return ["", "", "", "", ""]
351
+ if not Path(path).exists(): return None
352
+ user_agent = get_user_agent()
353
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
354
+ base_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
355
+ params = {}
356
+ session = requests.Session()
357
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
358
+ session.mount("https://", HTTPAdapter(max_retries=retries))
359
+ import hashlib
360
+ with open(path, 'rb') as file:
361
+ file_data = file.read()
362
+ hash_sha256 = hashlib.sha256(file_data).hexdigest()
363
+ url = base_url + hash_sha256
364
+ try:
365
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
366
+ except Exception as e:
367
+ print(e)
368
+ return ["", "", "", "", ""]
369
+ if not r.ok: return None
370
+ json = r.json()
371
+ if not 'baseModel' in json:
372
+ civitai_not_exists_list.append(path)
373
+ return ["", "", "", "", ""]
374
+ items = []
375
+ items.append(" / ".join(json['trainedWords']))
376
+ items.append(json['baseModel'])
377
+ items.append(json['model']['name'])
378
+ items.append(f"https://civitai.com/models/{json['modelId']}")
379
+ items.append(json['images'][0]['url'])
380
+ return items
381
+
382
+
383
+ def get_lora_model_list():
384
+ loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
385
+ loras.insert(0, "None")
386
+ loras.insert(0, "")
387
+ return loras
388
+
389
+
390
+ def get_all_lora_list():
391
+ global all_lora_list
392
+ loras = get_lora_model_list()
393
+ all_lora_list = loras.copy()
394
+ return loras
395
+
396
+
397
+ def get_all_lora_tupled_list():
398
+ global loras_dict
399
+ models = get_all_lora_list()
400
+ if not models: return []
401
+ tupled_list = []
402
+ for model in models:
403
+ #if not model: continue # to avoid GUI-related bug
404
+ basename = Path(model).stem
405
+ key = to_lora_key(model)
406
+ items = None
407
+ if key in loras_dict.keys():
408
+ items = loras_dict.get(key, None)
409
+ else:
410
+ items = get_civitai_info(model)
411
+ if items != None:
412
+ loras_dict[key] = items
413
+ name = basename
414
+ value = model
415
+ if items and items[2] != "":
416
+ if items[1] == "Pony":
417
+ name = f"{basename} (for {items[1]}🐴, {items[2]})"
418
+ else:
419
+ name = f"{basename} (for {items[1]}, {items[2]})"
420
+ tupled_list.append((name, value))
421
+ return tupled_list
422
+
423
+
424
+ def update_lora_dict(path):
425
+ global loras_dict
426
+ key = escape_lora_basename(Path(path).stem)
427
+ if key in loras_dict.keys(): return
428
+ items = get_civitai_info(path)
429
+ if items == None: return
430
+ loras_dict[key] = items
431
+
432
+
433
+ def download_lora(dl_urls: str):
434
+ global loras_url_to_path_dict
435
+ dl_path = ""
436
+ before = get_local_model_list(directory_loras)
437
+ urls = []
438
+ for url in [url.strip() for url in dl_urls.split(',')]:
439
+ local_path = f"{directory_loras}/{url.split('/')[-1]}"
440
+ if not Path(local_path).exists():
441
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
442
+ urls.append(url)
443
+ after = get_local_model_list(directory_loras)
444
+ new_files = list_sub(after, before)
445
+ i = 0
446
+ for file in new_files:
447
+ path = Path(file)
448
+ if path.exists():
449
+ new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
450
+ path.resolve().rename(new_path.resolve())
451
+ loras_url_to_path_dict[urls[i]] = str(new_path)
452
+ update_lora_dict(str(new_path))
453
+ dl_path = str(new_path)
454
+ i += 1
455
+ return dl_path
456
+
457
+
458
+ def copy_lora(path: str, new_path: str):
459
+ import shutil
460
+ if path == new_path: return new_path
461
+ cpath = Path(path)
462
+ npath = Path(new_path)
463
+ if cpath.exists():
464
+ try:
465
+ shutil.copy(str(cpath.resolve()), str(npath.resolve()))
466
+ except Exception as e:
467
+ print(e)
468
+ return None
469
+ update_lora_dict(str(npath))
470
+ return new_path
471
+ else:
472
+ return None
473
+
474
+
475
+ def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str):
476
+ path = download_lora(dl_urls)
477
+ if path:
478
+ if not lora1 or lora1 == "None":
479
+ lora1 = path
480
+ elif not lora2 or lora2 == "None":
481
+ lora2 = path
482
+ elif not lora3 or lora3 == "None":
483
+ lora3 = path
484
+ elif not lora4 or lora4 == "None":
485
+ lora4 = path
486
+ elif not lora5 or lora5 == "None":
487
+ lora5 = path
488
+ choices = get_all_lora_tupled_list()
489
+ return gr.update(value=lora1, choices=choices), gr.update(value=lora2, choices=choices), gr.update(value=lora3, choices=choices),\
490
+ gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
491
+
492
+
493
+ def get_valid_lora_name(query: str, model_name: str):
494
+ path = "None"
495
+ if not query or query == "None": return "None"
496
+ if to_lora_key(query) in loras_dict.keys(): return query
497
+ if query in loras_url_to_path_dict.keys():
498
+ path = loras_url_to_path_dict[query]
499
+ else:
500
+ path = to_lora_path(query.strip().split('/')[-1])
501
+ if Path(path).exists():
502
+ return path
503
+ elif "http" in query:
504
+ dl_file = download_lora(query)
505
+ if dl_file and Path(dl_file).exists(): return dl_file
506
+ else:
507
+ dl_file = find_similar_lora(query, model_name)
508
+ if dl_file and Path(dl_file).exists(): return dl_file
509
+ return "None"
510
+
511
+
512
+ def get_valid_lora_path(query: str):
513
+ path = None
514
+ if not query or query == "None": return None
515
+ if to_lora_key(query) in loras_dict.keys(): return query
516
+ if Path(path).exists():
517
+ return path
518
+ else:
519
+ return None
520
+
521
+
522
+ def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
523
+ import re
524
+ wt = lora_wt
525
+ result = re.findall(f'<lora:{to_lora_key(lora_path)}:(.+?)>', prompt)
526
+ if not result: return wt
527
+ wt = safe_float(result[0][0])
528
+ return wt
529
+
530
+
531
+ def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
532
+ import re
533
+ if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
534
+ lora1 = get_valid_lora_name(lora1, model_name)
535
+ lora2 = get_valid_lora_name(lora2, model_name)
536
+ lora3 = get_valid_lora_name(lora3, model_name)
537
+ lora4 = get_valid_lora_name(lora4, model_name)
538
+ lora5 = get_valid_lora_name(lora5, model_name)
539
+ if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
540
+ lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
541
+ lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
542
+ lora3_wt = get_valid_lora_wt(prompt, lora3, lora3_wt)
543
+ lora4_wt = get_valid_lora_wt(prompt, lora4, lora4_wt)
544
+ lora5_wt = get_valid_lora_wt(prompt, lora5, lora5_wt)
545
+ on1, label1, tag1, md1 = get_lora_info(lora1)
546
+ on2, label2, tag2, md2 = get_lora_info(lora2)
547
+ on3, label3, tag3, md3 = get_lora_info(lora3)
548
+ on4, label4, tag4, md4 = get_lora_info(lora4)
549
+ on5, label5, tag5, md5 = get_lora_info(lora5)
550
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
551
+ prompts = prompt.split(",") if prompt else []
552
+ for p in prompts:
553
+ p = str(p).strip()
554
+ if "<lora" in p:
555
+ result = re.findall(r'<lora:(.+?):(.+?)>', p)
556
+ if not result: continue
557
+ key = result[0][0]
558
+ wt = result[0][1]
559
+ path = to_lora_path(key)
560
+ if not key in loras_dict.keys() or not path:
561
+ path = get_valid_lora_name(path)
562
+ if not path or path == "None": continue
563
+ if path in lora_paths:
564
+ continue
565
+ elif not on1:
566
+ lora1 = path
567
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
568
+ lora1_wt = safe_float(wt)
569
+ on1 = True
570
+ elif not on2:
571
+ lora2 = path
572
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
573
+ lora2_wt = safe_float(wt)
574
+ on2 = True
575
+ elif not on3:
576
+ lora3 = path
577
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
578
+ lora3_wt = safe_float(wt)
579
+ on3 = True
580
+ elif not on4:
581
+ lora4 = path
582
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
583
+ lora4_wt = safe_float(wt)
584
+ on4, label4, tag4, md4 = get_lora_info(lora4)
585
+ elif not on5:
586
+ lora5 = path
587
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
588
+ lora5_wt = safe_float(wt)
589
+ on5 = True
590
+ return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
591
+
592
+
593
+ def get_lora_info(lora_path: str):
594
+ is_valid = False
595
+ tag = ""
596
+ label = ""
597
+ md = "None"
598
+ if not lora_path or lora_path == "None":
599
+ print("LoRA file not found.")
600
+ return is_valid, label, tag, md
601
+ path = Path(lora_path)
602
+ new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
603
+ if not to_lora_key(str(new_path)) in loras_dict.keys() and str(path) not in set(get_all_lora_list()):
604
+ print("LoRA file is not registered.")
605
+ return tag, label, tag, md
606
+ if not new_path.exists():
607
+ download_private_file_from_somewhere(str(path), True)
608
+ basename = new_path.stem
609
+ label = f'Name: {basename}'
610
+ items = loras_dict.get(basename, None)
611
+ if items == None:
612
+ items = get_civitai_info(str(new_path))
613
+ if items != None:
614
+ loras_dict[basename] = items
615
+ if items and items[2] != "":
616
+ tag = items[0]
617
+ label = f'Name: {basename}'
618
+ if items[1] == "Pony":
619
+ label = f'Name: {basename} (for Pony🐴)'
620
+ if items[4]:
621
+ md = f'<img src="{items[4]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL]({items[3]})'
622
+ elif items[3]:
623
+ md = f'[LoRA Model URL]({items[3]})'
624
+ is_valid = True
625
+ return is_valid, label, tag, md
626
+
627
+
628
+ def normalize_prompt_list(tags: list[str]):
629
+ prompts = []
630
+ for tag in tags:
631
+ tag = str(tag).strip()
632
+ if tag:
633
+ prompts.append(tag)
634
+ return prompts
635
+
636
+
637
+ def apply_lora_prompt(prompt: str = "", lora_info: str = ""):
638
+ if lora_info == "None": return gr.update(value=prompt)
639
+ tags = prompt.split(",") if prompt else []
640
+ prompts = normalize_prompt_list(tags)
641
+
642
+ lora_tag = lora_info.replace("/",",")
643
+ lora_tags = lora_tag.split(",") if str(lora_info) != "None" else []
644
+ lora_prompts = normalize_prompt_list(lora_tags)
645
+
646
+ empty = [""]
647
+ prompt = ", ".join(list_uniq(prompts + lora_prompts) + empty)
648
+ return gr.update(value=prompt)
649
+
650
+
651
+ def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
652
+ import re
653
+ on1, label1, tag1, md1 = get_lora_info(lora1)
654
+ on2, label2, tag2, md2 = get_lora_info(lora2)
655
+ on3, label3, tag3, md3 = get_lora_info(lora3)
656
+ on4, label4, tag4, md4 = get_lora_info(lora4)
657
+ on5, label5, tag5, md5 = get_lora_info(lora5)
658
+ lora_paths = [lora1, lora2, lora3, lora4, lora5]
659
+
660
+ output_prompt = prompt
661
+ if "Classic" in str(prompt_syntax):
662
+ prompts = prompt.split(",") if prompt else []
663
+ output_prompts = []
664
+ for p in prompts:
665
+ p = str(p).strip()
666
+ if "<lora" in p:
667
+ result = re.findall(r'<lora:(.+?):(.+?)>', p)
668
+ if not result: continue
669
+ key = result[0][0]
670
+ wt = result[0][1]
671
+ path = to_lora_path(key)
672
+ if not key in loras_dict.keys() or not path: continue
673
+ if path in lora_paths:
674
+ output_prompts.append(f"<lora:{to_lora_key(path)}:{safe_float(wt):.2f}>")
675
+ elif p:
676
+ output_prompts.append(p)
677
+ lora_prompts = []
678
+ if on1: lora_prompts.append(f"<lora:{to_lora_key(lora1)}:{lora1_wt:.2f}>")
679
+ if on2: lora_prompts.append(f"<lora:{to_lora_key(lora2)}:{lora2_wt:.2f}>")
680
+ if on3: lora_prompts.append(f"<lora:{to_lora_key(lora3)}:{lora3_wt:.2f}>")
681
+ if on4: lora_prompts.append(f"<lora:{to_lora_key(lora4)}:{lora4_wt:.2f}>")
682
+ if on5: lora_prompts.append(f"<lora:{to_lora_key(lora5)}:{lora5_wt:.2f}>")
683
+ output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts + [""]))
684
+ choices = get_all_lora_tupled_list()
685
+
686
+ return gr.update(value=output_prompt), gr.update(value=lora1, choices=choices), gr.update(value=lora1_wt),\
687
+ gr.update(value=tag1, label=label1, visible=on1), gr.update(visible=on1), gr.update(value=md1, visible=on1),\
688
+ gr.update(value=lora2, choices=choices), gr.update(value=lora2_wt),\
689
+ gr.update(value=tag2, label=label2, visible=on2), gr.update(visible=on2), gr.update(value=md2, visible=on2),\
690
+ gr.update(value=lora3, choices=choices), gr.update(value=lora3_wt),\
691
+ gr.update(value=tag3, label=label3, visible=on3), gr.update(visible=on3), gr.update(value=md3, visible=on3),\
692
+ gr.update(value=lora4, choices=choices), gr.update(value=lora4_wt),\
693
+ gr.update(value=tag4, label=label4, visible=on4), gr.update(visible=on4), gr.update(value=md4, visible=on4),\
694
+ gr.update(value=lora5, choices=choices), gr.update(value=lora5_wt),\
695
+ gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
696
+
697
+
698
+ def get_my_lora(link_url):
699
+ from pathlib import Path
700
+ before = get_local_model_list(directory_loras)
701
+ for url in [url.strip() for url in link_url.split(',')]:
702
+ if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
703
+ download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
704
+ after = get_local_model_list(directory_loras)
705
+ new_files = list_sub(after, before)
706
+ for file in new_files:
707
+ path = Path(file)
708
+ if path.exists():
709
+ new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
710
+ path.resolve().rename(new_path.resolve())
711
+ update_lora_dict(str(new_path))
712
+ new_lora_model_list = get_lora_model_list()
713
+ new_lora_tupled_list = get_all_lora_tupled_list()
714
+
715
+ return gr.update(
716
+ choices=new_lora_tupled_list, value=new_lora_model_list[-1]
717
+ ), gr.update(
718
+ choices=new_lora_tupled_list
719
+ ), gr.update(
720
+ choices=new_lora_tupled_list
721
+ ), gr.update(
722
+ choices=new_lora_tupled_list
723
+ ), gr.update(
724
+ choices=new_lora_tupled_list
725
+ )
726
+
727
+
728
+ def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)):
729
+ progress(0, desc="Uploading...")
730
+ file_paths = [file.name for file in files]
731
+ progress(1, desc="Uploaded.")
732
+ return gr.update(value=file_paths, visible=True), gr.update(visible=True)
733
+
734
+
735
+ def move_file_lora(filepaths):
736
+ import shutil
737
+ for file in filepaths:
738
+ path = Path(shutil.move(Path(file).resolve(), Path(f"./{directory_loras}").resolve()))
739
+ newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
740
+ path.resolve().rename(newpath.resolve())
741
+ update_lora_dict(str(newpath))
742
+
743
+ new_lora_model_list = get_lora_model_list()
744
+ new_lora_tupled_list = get_all_lora_tupled_list()
745
+
746
+ return gr.update(
747
+ choices=new_lora_tupled_list, value=new_lora_model_list[-1]
748
+ ), gr.update(
749
+ choices=new_lora_tupled_list
750
+ ), gr.update(
751
+ choices=new_lora_tupled_list
752
+ ), gr.update(
753
+ choices=new_lora_tupled_list
754
+ ), gr.update(
755
+ choices=new_lora_tupled_list
756
+ )
757
+
758
+
759
+ def get_civitai_info(path):
760
+ global civitai_not_exists_list, loras_url_to_path_dict
761
+ import requests
762
+ from requests.adapters import HTTPAdapter
763
+ from urllib3.util import Retry
764
+ default = ["", "", "", "", ""]
765
+ if path in set(civitai_not_exists_list): return default
766
+ if not Path(path).exists(): return None
767
+ user_agent = get_user_agent()
768
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
769
+ base_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
770
+ params = {}
771
+ session = requests.Session()
772
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
773
+ session.mount("https://", HTTPAdapter(max_retries=retries))
774
+ import hashlib
775
+ with open(path, 'rb') as file:
776
+ file_data = file.read()
777
+ hash_sha256 = hashlib.sha256(file_data).hexdigest()
778
+ url = base_url + hash_sha256
779
+ try:
780
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
781
+ except Exception as e:
782
+ print(e)
783
+ return default
784
+ else:
785
+ if not r.ok: return None
786
+ json = r.json()
787
+ if 'baseModel' not in json:
788
+ civitai_not_exists_list.append(path)
789
+ return default
790
+ items = []
791
+ items.append(" / ".join(json['trainedWords'])) # The words (prompts) used to trigger the model
792
+ items.append(json['baseModel']) # Base model (SDXL1.0, Pony, ...)
793
+ items.append(json['model']['name']) # The name of the model version
794
+ items.append(f"https://civitai.com/models/{json['modelId']}") # The repo url for the model
795
+ items.append(json['images'][0]['url']) # The url for a sample image
796
+ loras_url_to_path_dict[path] = json['downloadUrl'] # The download url to get the model file for this specific version
797
+ return items
798
+
799
+
800
+ def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
801
+ sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
802
+ import requests
803
+ from requests.adapters import HTTPAdapter
804
+ from urllib3.util import Retry
805
+ user_agent = get_user_agent()
806
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
807
+ base_url = 'https://civitai.com/api/v1/models'
808
+ params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
809
+ if query: params["query"] = query
810
+ if tag: params["tag"] = tag
811
+ session = requests.Session()
812
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
813
+ session.mount("https://", HTTPAdapter(max_retries=retries))
814
+ try:
815
+ r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(3.0, 30))
816
+ except Exception as e:
817
+ print(e)
818
+ return None
819
+ else:
820
+ if not r.ok: return None
821
+ json = r.json()
822
+ if 'items' not in json: return None
823
+ items = []
824
+ for j in json['items']:
825
+ for model in j['modelVersions']:
826
+ item = {}
827
+ if model['baseModel'] not in set(allow_model): continue
828
+ item['name'] = j['name']
829
+ item['creator'] = j['creator']['username']
830
+ item['tags'] = j['tags']
831
+ item['model_name'] = model['name']
832
+ item['base_model'] = model['baseModel']
833
+ item['dl_url'] = model['downloadUrl']
834
+ item['md'] = f'<img src="{model["images"][0]["url"]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL](https://civitai.com/models/{j["id"]})'
835
+ items.append(item)
836
+ return items
837
+
838
+
839
+ def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
840
+ global civitai_lora_last_results
841
+ items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
842
+ if not items: return gr.update(choices=[("", "")], value="", visible=False),\
843
+ gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
844
+ civitai_lora_last_results = {}
845
+ choices = []
846
+ for item in items:
847
+ base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
848
+ name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
849
+ value = item['dl_url']
850
+ choices.append((name, value))
851
+ civitai_lora_last_results[value] = item
852
+ if not choices: return gr.update(choices=[("", "")], value="", visible=False),\
853
+ gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
854
+ result = civitai_lora_last_results.get(choices[0][1], "None")
855
+ md = result['md'] if result else ""
856
+ return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
857
+ gr.update(visible=True), gr.update(visible=True)
858
+
859
+
860
+ def select_civitai_lora(search_result):
861
+ if not "http" in search_result: return gr.update(value=""), gr.update(value="None", visible=True)
862
+ result = civitai_lora_last_results.get(search_result, "None")
863
+ md = result['md'] if result else ""
864
+ return gr.update(value=search_result), gr.update(value=md, visible=True)
865
+
866
+
867
+ LORA_BASE_MODEL_DICT = {
868
+ "diffusers:StableDiffusionPipeline": ["SD 1.5"],
869
+ "diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
870
+ "diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
871
+ }
872
+
873
+
874
+ def get_lora_base_model(model_name: str):
875
+ api = HfApi(token=HF_TOKEN)
876
+ default = ["Pony", "SDXL 1.0"]
877
+ try:
878
+ model = api.model_info(repo_id=model_name, timeout=5.0)
879
+ tags = model.tags
880
+ for tag in tags:
881
+ if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
882
+ except Exception:
883
+ return default
884
+ return default
885
+
886
+
887
+ def find_similar_lora(q: str, model_name: str):
888
+ from rapidfuzz.process import extractOne
889
+ from rapidfuzz.utils import default_process
890
+ query = to_lora_key(q)
891
+ print(f"Finding <lora:{query}:...>...")
892
+ keys = list(private_lora_dict.keys())
893
+ values = [x[2] for x in list(private_lora_dict.values())]
894
+ s = default_process(query)
895
+ e1 = extractOne(s, keys + values, processor=default_process, score_cutoff=80.0)
896
+ key = ""
897
+ if e1:
898
+ e = e1[0]
899
+ if e in set(keys): key = e
900
+ elif e in set(values): key = keys[values.index(e)]
901
+ if key:
902
+ path = to_lora_path(key)
903
+ new_path = to_lora_path(query)
904
+ if not Path(path).exists():
905
+ if not Path(new_path).exists(): download_private_file_from_somewhere(path, True)
906
+ if Path(path).exists() and copy_lora(path, new_path): return new_path
907
+ print(f"Finding <lora:{query}:...> on Civitai...")
908
+ civitai_query = Path(query).stem if Path(query).is_file() else query
909
+ civitai_query = civitai_query.replace("_", " ").replace("-", " ")
910
+ base_model = get_lora_base_model(model_name)
911
+ items = search_lora_on_civitai(civitai_query, base_model, 1)
912
+ if items:
913
+ item = items[0]
914
+ path = download_lora(item['dl_url'])
915
+ new_path = query if Path(query).is_file() else to_lora_path(query)
916
+ if path and copy_lora(path, new_path): return new_path
917
+ return None
918
+
919
+
920
+ def change_interface_mode(mode: str):
921
+ if mode == "Fast":
922
+ return gr.update(open=False), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
923
+ gr.update(visible=True), gr.update(open=False), gr.update(visible=True), gr.update(open=False),\
924
+ gr.update(visible=True), gr.update(value="Fast")
925
+ elif mode == "Simple": # t2i mode
926
+ return gr.update(open=True), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
927
+ gr.update(visible=True), gr.update(open=False), gr.update(visible=False), gr.update(open=True),\
928
+ gr.update(visible=False), gr.update(value="Standard")
929
+ elif mode == "LoRA": # t2i LoRA mode
930
+ return gr.update(open=True), gr.update(visible=True), gr.update(open=True), gr.update(open=False),\
931
+ gr.update(visible=True), gr.update(open=True), gr.update(visible=True), gr.update(open=False),\
932
+ gr.update(visible=False), gr.update(value="Standard")
933
+ else: # Standard
934
+ return gr.update(open=False), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
935
+ gr.update(visible=True), gr.update(open=False), gr.update(visible=True), gr.update(open=False),\
936
+ gr.update(visible=True), gr.update(value="Standard")
937
+
938
+
939
+ quality_prompt_list = [
940
+ {
941
+ "name": "None",
942
+ "prompt": "",
943
+ "negative_prompt": "lowres",
944
+ },
945
+ {
946
+ "name": "Animagine Common",
947
+ "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
948
+ "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
949
+ },
950
+ {
951
+ "name": "Pony Anime Common",
952
+ "prompt": "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
953
+ "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
954
+ },
955
+ {
956
+ "name": "Pony Common",
957
+ "prompt": "source_anime, score_9, score_8_up, score_7_up",
958
+ "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
959
+ },
960
+ {
961
+ "name": "Animagine Standard v3.0",
962
+ "prompt": "masterpiece, best quality",
963
+ "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
964
+ },
965
+ {
966
+ "name": "Animagine Standard v3.1",
967
+ "prompt": "masterpiece, best quality, very aesthetic, absurdres",
968
+ "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
969
+ },
970
+ {
971
+ "name": "Animagine Light v3.1",
972
+ "prompt": "(masterpiece), best quality, very aesthetic, perfect face",
973
+ "negative_prompt": "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
974
+ },
975
+ {
976
+ "name": "Animagine Heavy v3.1",
977
+ "prompt": "(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
978
+ "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
979
+ },
980
+ ]
981
+
982
+
983
+ style_list = [
984
+ {
985
+ "name": "None",
986
+ "prompt": "",
987
+ "negative_prompt": "",
988
+ },
989
+ {
990
+ "name": "Cinematic",
991
+ "prompt": "cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
992
+ "negative_prompt": "cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
993
+ },
994
+ {
995
+ "name": "Photographic",
996
+ "prompt": "cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
997
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
998
+ },
999
+ {
1000
+ "name": "Anime",
1001
+ "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed",
1002
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
1003
+ },
1004
+ {
1005
+ "name": "Manga",
1006
+ "prompt": "manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
1007
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
1008
+ },
1009
+ {
1010
+ "name": "Digital Art",
1011
+ "prompt": "concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
1012
+ "negative_prompt": "photo, photorealistic, realism, ugly",
1013
+ },
1014
+ {
1015
+ "name": "Pixel art",
1016
+ "prompt": "pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
1017
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
1018
+ },
1019
+ {
1020
+ "name": "Fantasy art",
1021
+ "prompt": "ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
1022
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
1023
+ },
1024
+ {
1025
+ "name": "Neonpunk",
1026
+ "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
1027
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
1028
+ },
1029
+ {
1030
+ "name": "3D Model",
1031
+ "prompt": "professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
1032
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
1033
+ },
1034
+ ]
1035
+
1036
+
1037
+ optimization_list = {
1038
+ "None": [28, 7., 'Euler a', False, 'None', 1.],
1039
+ "Default": [28, 7., 'Euler a', False, 'None', 1.],
1040
+ "SPO": [28, 7., 'Euler a', True, 'loras/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors', 1.],
1041
+ "DPO": [28, 7., 'Euler a', True, 'loras/sdxl-DPO-LoRA.safetensors', 1.],
1042
+ "DPO Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_dpo_turbo_lora_v1-128dim.safetensors', 1.],
1043
+ "SDXL Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_turbo_lora_v1.safetensors', 1.],
1044
+ "Hyper-SDXL 12step": [12, 5., 'TCD', True, 'loras/Hyper-SDXL-12steps-CFG-lora.safetensors', 1.],
1045
+ "Hyper-SDXL 8step": [8, 5., 'TCD', True, 'loras/Hyper-SDXL-8steps-CFG-lora.safetensors', 1.],
1046
+ "Hyper-SDXL 4step": [4, 0, 'TCD', True, 'loras/Hyper-SDXL-4steps-lora.safetensors', 1.],
1047
+ "Hyper-SDXL 2step": [2, 0, 'TCD', True, 'loras/Hyper-SDXL-2steps-lora.safetensors', 1.],
1048
+ "Hyper-SDXL 1step": [1, 0, 'TCD', True, 'loras/Hyper-SDXL-1steps-lora.safetensors', 1.],
1049
+ "PCM 16step": [16, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_16step_converted.safetensors', 1.],
1050
+ "PCM 8step": [8, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_8step_converted.safetensors', 1.],
1051
+ "PCM 4step": [4, 2., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_4step_converted.safetensors', 1.],
1052
+ "PCM 2step": [2, 1., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_2step_converted.safetensors', 1.],
1053
+ }
1054
+
1055
+
1056
+ def set_optimization(opt, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora_gui, lora_scale_gui):
1057
+ if not opt in list(optimization_list.keys()): opt = "None"
1058
+ def_steps_gui = 28
1059
+ def_cfg_gui = 7.
1060
+ steps = optimization_list.get(opt, "None")[0]
1061
+ cfg = optimization_list.get(opt, "None")[1]
1062
+ sampler = optimization_list.get(opt, "None")[2]
1063
+ clip_skip = optimization_list.get(opt, "None")[3]
1064
+ lora = optimization_list.get(opt, "None")[4]
1065
+ lora_scale = optimization_list.get(opt, "None")[5]
1066
+ if opt == "None":
1067
+ steps = max(steps_gui, def_steps_gui)
1068
+ cfg = max(cfg_gui, def_cfg_gui)
1069
+ clip_skip = clip_skip_gui
1070
+ elif opt == "SPO" or opt == "DPO":
1071
+ steps = max(steps_gui, def_steps_gui)
1072
+ cfg = max(cfg_gui, def_cfg_gui)
1073
+
1074
+ return gr.update(value=steps), gr.update(value=cfg), gr.update(value=sampler),\
1075
+ gr.update(value=clip_skip), gr.update(value=lora), gr.update(value=lora_scale),
1076
+
1077
+
1078
+ # [sampler_gui, steps_gui, cfg_gui, clip_skip_gui, img_width_gui, img_height_gui, optimization_gui]
1079
+ preset_sampler_setting = {
1080
+ "None": ["Euler a", 28, 7., True, 1024, 1024, "None"],
1081
+ "Anime 3:4 Fast": ["LCM", 8, 2.5, True, 896, 1152, "DPO Turbo"],
1082
+ "Anime 3:4 Standard": ["Euler a", 28, 7., True, 896, 1152, "None"],
1083
+ "Anime 3:4 Heavy": ["Euler a", 40, 7., True, 896, 1152, "None"],
1084
+ "Anime 1:1 Fast": ["LCM", 8, 2.5, True, 1024, 1024, "DPO Turbo"],
1085
+ "Anime 1:1 Standard": ["Euler a", 28, 7., True, 1024, 1024, "None"],
1086
+ "Anime 1:1 Heavy": ["Euler a", 40, 7., True, 1024, 1024, "None"],
1087
+ "Photo 3:4 Fast": ["LCM", 8, 2.5, False, 896, 1152, "DPO Turbo"],
1088
+ "Photo 3:4 Standard": ["DPM++ 2M Karras", 28, 7., False, 896, 1152, "None"],
1089
+ "Photo 3:4 Heavy": ["DPM++ 2M Karras", 40, 7., False, 896, 1152, "None"],
1090
+ "Photo 1:1 Fast": ["LCM", 8, 2.5, False, 1024, 1024, "DPO Turbo"],
1091
+ "Photo 1:1 Standard": ["DPM++ 2M Karras", 28, 7., False, 1024, 1024, "None"],
1092
+ "Photo 1:1 Heavy": ["DPM++ 2M Karras", 40, 7., False, 1024, 1024, "None"],
1093
+ }
1094
+
1095
+
1096
+ def set_sampler_settings(sampler_setting):
1097
+ if not sampler_setting in list(preset_sampler_setting.keys()) or sampler_setting == "None":
1098
+ return gr.update(value="Euler a"), gr.update(value=28), gr.update(value=7.), gr.update(value=True),\
1099
+ gr.update(value=1024), gr.update(value=1024), gr.update(value="None")
1100
+ v = preset_sampler_setting.get(sampler_setting, ["Euler a", 28, 7., True, 1024, 1024])
1101
+ # sampler, steps, cfg, clip_skip, width, height, optimization
1102
+ return gr.update(value=v[0]), gr.update(value=v[1]), gr.update(value=v[2]), gr.update(value=v[3]),\
1103
+ gr.update(value=v[4]), gr.update(value=v[5]), gr.update(value=v[6])
1104
+
1105
+
1106
+ preset_styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
1107
+ preset_quality = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
1108
+
1109
+
1110
+ def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None", type: str = "Auto"):
1111
+ def to_list(s):
1112
+ return [x.strip() for x in s.split(",") if not s == ""]
1113
+
1114
+ def list_sub(a, b):
1115
+ return [e for e in a if e not in b]
1116
+
1117
+ def list_uniq(l):
1118
+ return sorted(set(l), key=l.index)
1119
+
1120
+ animagine_ps = to_list("anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
1121
+ animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
1122
+ pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
1123
+ pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
1124
+ prompts = to_list(prompt)
1125
+ neg_prompts = to_list(neg_prompt)
1126
+
1127
+ all_styles_ps = []
1128
+ all_styles_nps = []
1129
+ for d in style_list:
1130
+ all_styles_ps.extend(to_list(str(d.get("prompt", ""))))
1131
+ all_styles_nps.extend(to_list(str(d.get("negative_prompt", ""))))
1132
+
1133
+ all_quality_ps = []
1134
+ all_quality_nps = []
1135
+ for d in quality_prompt_list:
1136
+ all_quality_ps.extend(to_list(str(d.get("prompt", ""))))
1137
+ all_quality_nps.extend(to_list(str(d.get("negative_prompt", ""))))
1138
+
1139
+ quality_ps = to_list(preset_quality[quality_key][0])
1140
+ quality_nps = to_list(preset_quality[quality_key][1])
1141
+ styles_ps = to_list(preset_styles[styles_key][0])
1142
+ styles_nps = to_list(preset_styles[styles_key][1])
1143
+
1144
+ prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
1145
+ neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)
1146
+
1147
+ last_empty_p = [""] if not prompts and type != "None" and type != "Auto" and styles_key != "None" and quality_key != "None" else []
1148
+ last_empty_np = [""] if not neg_prompts and type != "None" and type != "Auto" and styles_key != "None" and quality_key != "None" else []
1149
+
1150
+ if type == "Animagine":
1151
+ prompts = prompts + animagine_ps
1152
+ neg_prompts = neg_prompts + animagine_nps
1153
+ elif type == "Pony":
1154
+ prompts = prompts + pony_ps
1155
+ neg_prompts = neg_prompts + pony_nps
1156
+
1157
+ prompts = prompts + styles_ps + quality_ps
1158
+ neg_prompts = neg_prompts + styles_nps + quality_nps
1159
+
1160
+ prompt = ", ".join(list_uniq(prompts) + last_empty_p)
1161
+ neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
1162
+
1163
+ return gr.update(value=prompt), gr.update(value=neg_prompt), gr.update(value=type)
1164
+
1165
+
1166
+ def set_quick_presets(genre:str = "None", type:str = "Auto", speed:str = "None", aspect:str = "None"):
1167
+ quality = "None"
1168
+ style = "None"
1169
+ sampler = "None"
1170
+ opt = "None"
1171
+
1172
+ if genre == "Anime":
1173
+ if type != "None" and type != "Auto": style = "Anime"
1174
+ if aspect == "1:1":
1175
+ if speed == "Heavy":
1176
+ sampler = "Anime 1:1 Heavy"
1177
+ elif speed == "Fast":
1178
+ sampler = "Anime 1:1 Fast"
1179
+ else:
1180
+ sampler = "Anime 1:1 Standard"
1181
+ elif aspect == "3:4":
1182
+ if speed == "Heavy":
1183
+ sampler = "Anime 3:4 Heavy"
1184
+ elif speed == "Fast":
1185
+ sampler = "Anime 3:4 Fast"
1186
+ else:
1187
+ sampler = "Anime 3:4 Standard"
1188
+ if type == "Pony":
1189
+ quality = "Pony Anime Common"
1190
+ elif type == "Animagine":
1191
+ quality = "Animagine Common"
1192
+ else:
1193
+ quality = "None"
1194
+ elif genre == "Photo":
1195
+ if type != "None" and type != "Auto": style = "Photographic"
1196
+ if aspect == "1:1":
1197
+ if speed == "Heavy":
1198
+ sampler = "Photo 1:1 Heavy"
1199
+ elif speed == "Fast":
1200
+ sampler = "Photo 1:1 Fast"
1201
+ else:
1202
+ sampler = "Photo 1:1 Standard"
1203
+ elif aspect == "3:4":
1204
+ if speed == "Heavy":
1205
+ sampler = "Photo 3:4 Heavy"
1206
+ elif speed == "Fast":
1207
+ sampler = "Photo 3:4 Fast"
1208
+ else:
1209
+ sampler = "Photo 3:4 Standard"
1210
+ if type == "Pony":
1211
+ quality = "Pony Common"
1212
+ else:
1213
+ quality = "None"
1214
+
1215
+ if speed == "Fast":
1216
+ opt = "DPO Turbo"
1217
+ if genre == "Anime" and type != "Pony" and type != "Auto": quality = "Animagine Light v3.1"
1218
+
1219
+ return gr.update(value=quality), gr.update(value=style), gr.update(value=sampler), gr.update(value=opt), gr.update(value=type)
1220
+
1221
+
1222
+ textual_inversion_dict = {}
1223
+ try:
1224
+ with open('textual_inversion_dict.json', encoding='utf-8') as f:
1225
+ textual_inversion_dict = json.load(f)
1226
+ except Exception:
1227
+ pass
1228
+ textual_inversion_file_token_list = []
1229
+
1230
+
1231
+ def get_tupled_embed_list(embed_list):
1232
+ global textual_inversion_file_list
1233
+ tupled_list = []
1234
+ for file in embed_list:
1235
+ token = textual_inversion_dict.get(Path(file).name, [Path(file).stem.replace(",",""), False])[0]
1236
+ tupled_list.append((token, file))
1237
+ textual_inversion_file_token_list.append(token)
1238
+ return tupled_list
1239
+
1240
+
1241
+ def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_gui, prompt_syntax_gui):
1242
+ ti_tags = list(textual_inversion_dict.values()) + textual_inversion_file_token_list
1243
+ tags = prompt_gui.split(",") if prompt_gui else []
1244
+ prompts = []
1245
+ for tag in tags:
1246
+ tag = str(tag).strip()
1247
+ if tag and not tag in ti_tags:
1248
+ prompts.append(tag)
1249
+ ntags = neg_prompt_gui.split(",") if neg_prompt_gui else []
1250
+ neg_prompts = []
1251
+ for tag in ntags:
1252
+ tag = str(tag).strip()
1253
+ if tag and not tag in ti_tags:
1254
+ neg_prompts.append(tag)
1255
+ ti_prompts = []
1256
+ ti_neg_prompts = []
1257
+ for ti in textual_inversion_gui:
1258
+ tokens = textual_inversion_dict.get(Path(ti).name, [Path(ti).stem.replace(",",""), False])
1259
+ is_positive = tokens[1] == True or "positive" in Path(ti).parent.name
1260
+ if is_positive: # positive prompt
1261
+ ti_prompts.append(tokens[0])
1262
+ else: # negative prompt (default)
1263
+ ti_neg_prompts.append(tokens[0])
1264
+ empty = [""]
1265
+ prompt = ", ".join(prompts + ti_prompts + empty)
1266
+ neg_prompt = ", ".join(neg_prompts + ti_neg_prompts + empty)
1267
+ return gr.update(value=prompt), gr.update(value=neg_prompt),
1268
+
1269
+
1270
+ def get_model_pipeline(repo_id: str):
1271
+ from huggingface_hub import HfApi
1272
+ api = HfApi(token=HF_TOKEN)
1273
+ default = "StableDiffusionPipeline"
1274
+ try:
1275
+ if not is_repo_name(repo_id): return default
1276
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
1277
+ except Exception:
1278
+ return default
1279
+ if model.private or model.gated: return default
1280
+ tags = model.tags
1281
+ if not 'diffusers' in tags: return default
1282
+ if 'diffusers:FluxPipeline' in tags:
1283
+ return "FluxPipeline"
1284
+ if 'diffusers:StableDiffusionXLPipeline' in tags:
1285
+ return "StableDiffusionXLPipeline"
1286
+ elif 'diffusers:StableDiffusionPipeline' in tags:
1287
+ return "StableDiffusionPipeline"
1288
+ else:
1289
+ return default
1290
+