r3gm committed on
Commit
b647d6d
1 Parent(s): d6eb980

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -112
app.py CHANGED
@@ -3,8 +3,10 @@ import os
3
  from stablepy import Model_Diffusers
4
  from stablepy.diffusers_vanilla.model import scheduler_names
5
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
 
6
  import torch
7
  import re
 
8
  import shutil
9
  import random
10
  from stablepy import (
@@ -123,7 +125,7 @@ task_stablepy = {
123
  'tile ControlNet': 'tile',
124
  }
125
 
126
- task_model_list = list(task_stablepy.keys())
127
 
128
 
129
  def download_things(directory, url, hf_token="", civitai_api_key=""):
@@ -329,6 +331,7 @@ upscaler_dict_gui = {
329
 
330
  upscaler_keys = list(upscaler_dict_gui.keys())
331
 
 
332
  def extract_parameters(input_string):
333
  parameters = {}
334
  input_string = input_string.replace("\n", "")
@@ -374,10 +377,6 @@ def extract_parameters(input_string):
374
  #######################
375
  import spaces
376
  import gradio as gr
377
- from PIL import Image
378
- import IPython.display
379
- import time, json
380
- from IPython.utils import capture
381
  import logging
382
  logging.getLogger("diffusers").setLevel(logging.ERROR)
383
  import diffusers
@@ -387,8 +386,14 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffuse
387
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
388
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
389
  from stablepy import logger
 
390
  logger.setLevel(logging.DEBUG)
391
 
 
 
 
 
 
392
 
393
  def info_html(json_data, title, subtitle):
394
  return f"""
@@ -402,6 +407,19 @@ def info_html(json_data, title, subtitle):
402
  """
403
 
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  class GuiSD:
406
  def __init__(self, stream=True):
407
  self.model = None
@@ -421,23 +439,21 @@ class GuiSD:
421
  yield f"Loading model: {model_name}"
422
 
423
  vae_model = vae_model if vae_model != "None" else None
 
424
 
425
- if model_name in model_list:
426
- model_is_xl = "xl" in model_name.lower()
427
- sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
428
- model_type = "SDXL" if model_is_xl else "SD 1.5"
429
- incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
430
-
431
- if incompatible_vae:
432
- vae_model = None
433
 
434
  self.model.device = torch.device("cpu")
 
435
 
436
  self.model.load_pipe(
437
  model_name,
438
  task_name=task_stablepy[task],
439
- vae_model=vae_model if vae_model != "None" else None,
440
- type_model_precision=torch.float16 if "flux" not in model_name.lower() else torch.bfloat16,
441
  retain_task_model_in_cache=False,
442
  )
443
  yield f"Model loaded: {model_name}"
@@ -555,30 +571,7 @@ class GuiSD:
555
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
556
  msg_lora = []
557
 
558
- if model_name in model_list:
559
- model_is_xl = "xl" in model_name.lower()
560
- sdxl_in_vae = vae_model and "sdxl" in vae_model.lower()
561
- model_type = "SDXL" if model_is_xl else "SD 1.5"
562
- incompatible_vae = (model_is_xl and vae_model and not sdxl_in_vae) or (not model_is_xl and sdxl_in_vae)
563
-
564
- if incompatible_vae:
565
- msg_inc_vae = (
566
- f"The selected VAE is for a { 'SD 1.5' if model_is_xl else 'SDXL' } model, but you"
567
- f" are using a { model_type } model. The default VAE "
568
- "will be used."
569
- )
570
- gr.Info(msg_inc_vae)
571
- vae_msg = msg_inc_vae
572
- vae_model = None
573
-
574
- for la in loras_list:
575
- if la is not None and la != "None" and la in lora_model_list:
576
- print(la)
577
- lora_type = ("animetarot" in la.lower() or "Hyper-SD15-8steps".lower() in la.lower())
578
- if (model_is_xl and lora_type) or (not model_is_xl and not lora_type):
579
- msg_inc_lora = f"The LoRA {la} is for { 'SD 1.5' if model_is_xl else 'SDXL' }, but you are using { model_type }."
580
- gr.Info(msg_inc_lora)
581
- msg_lora.append(msg_inc_lora)
582
 
583
  task = task_stablepy[task]
584
 
@@ -602,18 +595,6 @@ class GuiSD:
602
  params_ip_mode.append(modeip)
603
  params_ip_scale.append(scaleip)
604
 
605
- model_precision = torch.float16 if "flux" not in model_name.lower() else torch.bfloat16
606
-
607
- # First load
608
- if not self.model:
609
- print("Loading model...")
610
- self.model = Model_Diffusers(
611
- base_model_id=model_name,
612
- task_name=task,
613
- vae_model=vae_model if vae_model != "None" else None,
614
- type_model_precision=model_precision,
615
- retain_task_model_in_cache=retain_task_cache_gui,
616
- )
617
  self.model.stream_config(concurrency=5, latent_resize_by=1, vae_decoding=False)
618
 
619
  if task != "txt2img" and not image_control:
@@ -637,45 +618,32 @@ class GuiSD:
637
 
638
  logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)
639
 
640
- print("Config model:", model_name, vae_model, loras_list)
641
-
642
- self.model.load_pipe(
643
- model_name,
644
- task_name=task,
645
- vae_model=vae_model if vae_model != "None" else None,
646
- type_model_precision=model_precision,
647
- retain_task_model_in_cache=retain_task_cache_gui,
648
- )
649
-
650
- if textual_inversion and self.model.class_name == "StableDiffusionXLPipeline":
651
- print("No Textual inversion for SDXL")
652
-
653
  adetailer_params_A = {
654
- "face_detector_ad" : face_detector_ad_a,
655
- "person_detector_ad" : person_detector_ad_a,
656
- "hand_detector_ad" : hand_detector_ad_a,
657
  "prompt": prompt_ad_a,
658
- "negative_prompt" : negative_prompt_ad_a,
659
- "strength" : strength_ad_a,
660
  # "image_list_task" : None,
661
- "mask_dilation" : mask_dilation_a,
662
- "mask_blur" : mask_blur_a,
663
- "mask_padding" : mask_padding_a,
664
- "inpaint_only" : adetailer_inpaint_only,
665
- "sampler" : adetailer_sampler,
666
  }
667
 
668
  adetailer_params_B = {
669
- "face_detector_ad" : face_detector_ad_b,
670
- "person_detector_ad" : person_detector_ad_b,
671
- "hand_detector_ad" : hand_detector_ad_b,
672
  "prompt": prompt_ad_b,
673
- "negative_prompt" : negative_prompt_ad_b,
674
- "strength" : strength_ad_b,
675
  # "image_list_task" : None,
676
- "mask_dilation" : mask_dilation_b,
677
- "mask_blur" : mask_blur_b,
678
- "mask_padding" : mask_padding_b,
679
  }
680
  pipe_params = {
681
  "prompt": prompt,
@@ -759,7 +727,7 @@ class GuiSD:
759
  if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
760
  self.model.pipe.transformer.to(self.model.device)
761
  print("transformer to cuda")
762
-
763
  info_state = "PROCESSING "
764
  for img, seed, image_path, metadata in self.model(**pipe_params):
765
  info_state += ">"
@@ -786,42 +754,53 @@ class GuiSD:
786
 
787
  sd_gen = GuiSD()
788
 
789
- CSS ="""
790
  .contain { display: flex; flex-direction: column; }
791
  #component-0 { height: 100%; }
792
  #gallery { flex-grow: 1; }
793
  """
794
- sdxl_task = [k for k, v in task_stablepy.items() if v in SDXL_TASKS ]
795
- sd_task = [k for k, v in task_stablepy.items() if v in SD15_TASKS ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
  def update_task_options(model_name, task_name):
797
- if model_name in model_list:
798
- if "xl" in model_name.lower() or "pony" in model_name.lower():
799
- new_choices = sdxl_task
800
- else:
801
- new_choices = sd_task
802
 
803
- if task_name not in new_choices:
804
- task_name = "txt2img"
 
 
805
 
806
- return gr.update(value=task_name, choices=new_choices)
807
- else:
808
- return gr.update(value=task_name, choices=task_model_list)
809
 
810
  POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
811
 
 
 
 
 
 
812
  with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
813
  gr.Markdown("# 🧩 DiffuseCraft")
814
- gr.Markdown(
815
- f"""
816
- ### This demo uses [diffusers](https://github.com/huggingface/diffusers) to perform different tasks in image generation.
817
- """
818
- )
819
  with gr.Tab("Generation"):
820
  with gr.Row():
821
 
822
  with gr.Column(scale=2):
823
 
824
- task_gui = gr.Dropdown(label="Task", choices=sdxl_task, value=task_model_list[0])
825
  model_name_gui = gr.Dropdown(label="Model", choices=model_list, value=model_list[0], allow_custom_value=True)
826
  prompt_gui = gr.Textbox(lines=5, placeholder="Enter prompt", label="Prompt")
827
  neg_prompt_gui = gr.Textbox(lines=3, placeholder="Enter Neg prompt", label="Negative prompt")
@@ -1140,7 +1119,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1140
  gr.Markdown(
1141
  """### The following examples perform specific tasks:
1142
  1. Generation with SDXL and upscale
1143
- 2. Generation with SDXL
1144
  3. ControlNet Canny SDXL
1145
  4. Optical pattern (Optical illusion) SDXL
1146
  5. Convert an image to a coloring drawing
@@ -1195,11 +1174,11 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1195
  "Nearest",
1196
  ],
1197
  [
1198
- "score_9, score_8_up, score_8, medium breasts, cute, eyelashes , princess Zelda OOT, cute small face, long hair, crown braid, hairclip, pointy ears, soft curvy body, solo, looking at viewer, smile, blush, white dress, medium body, (((holding the Master Sword))), standing, deep forest in the background",
1199
- "score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white,",
1200
  1,
1201
- 30,
1202
- 5.,
1203
  True,
1204
  -1,
1205
  "None",
@@ -1207,15 +1186,15 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1207
  "None",
1208
  1.0,
1209
  "None",
1210
- 1.0,
1211
  "None",
1212
  1.0,
1213
  "None",
1214
  1.0,
1215
- "DPM++ 2M Karras",
1216
  1024,
1217
  1024,
1218
- "kitty7779/ponyDiffusionV6XL",
1219
  None, # vae
1220
  "txt2img",
1221
  None, # img control
@@ -1235,7 +1214,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1235
  1., # cn end
1236
  False, # ti
1237
  "Classic",
1238
- "Nearest",
1239
  ],
1240
  [
1241
  "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff ",
@@ -1648,4 +1627,4 @@ app.launch(
1648
  show_error=True,
1649
  debug=True,
1650
  allowed_paths=["./images/"],
1651
- )
 
3
  from stablepy import Model_Diffusers
4
  from stablepy.diffusers_vanilla.model import scheduler_names
5
  from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
6
+ from stablepy.diffusers_vanilla.constants import FLUX_CN_UNION_MODES
7
  import torch
8
  import re
9
+ from huggingface_hub import HfApi
10
  import shutil
11
  import random
12
  from stablepy import (
 
125
  'tile ControlNet': 'tile',
126
  }
127
 
128
+ TASK_MODEL_LIST = list(task_stablepy.keys())
129
 
130
 
131
  def download_things(directory, url, hf_token="", civitai_api_key=""):
 
331
 
332
  upscaler_keys = list(upscaler_dict_gui.keys())
333
 
334
+
335
  def extract_parameters(input_string):
336
  parameters = {}
337
  input_string = input_string.replace("\n", "")
 
377
  #######################
378
  import spaces
379
  import gradio as gr
 
 
 
 
380
  import logging
381
  logging.getLogger("diffusers").setLevel(logging.ERROR)
382
  import diffusers
 
386
  warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
387
  warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
388
  from stablepy import logger
389
+
390
  logger.setLevel(logging.DEBUG)
391
 
392
+ msg_inc_vae = (
393
+ "Use the right VAE for your model to maintain image quality. The wrong"
394
+ " VAE can lead to poor results, like blurriness in the generated images."
395
+ )
396
+
397
 
398
  def info_html(json_data, title, subtitle):
399
  return f"""
 
407
  """
408
 
409
 
410
+ def get_model_type(repo_id: str):
411
+ api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
412
+ default = "SD 1.5"
413
+ try:
414
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
415
+ tags = model.tags
416
+ for tag in tags:
417
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
418
+ except Exception:
419
+ return default
420
+ return default
421
+
422
+
423
  class GuiSD:
424
  def __init__(self, stream=True):
425
  self.model = None
 
439
  yield f"Loading model: {model_name}"
440
 
441
  vae_model = vae_model if vae_model != "None" else None
442
+ model_type = get_model_type(model_name)
443
 
444
+ if vae_model:
445
+ vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
446
+ if model_type != vae_type:
447
+ gr.Info(msg_inc_vae)
 
 
 
 
448
 
449
  self.model.device = torch.device("cpu")
450
+ dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
451
 
452
  self.model.load_pipe(
453
  model_name,
454
  task_name=task_stablepy[task],
455
+ vae_model=vae_model,
456
+ type_model_precision=dtype_model,
457
  retain_task_model_in_cache=False,
458
  )
459
  yield f"Model loaded: {model_name}"
 
571
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
572
  msg_lora = []
573
 
574
+ print("Config model:", model_name, vae_model, loras_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
 
576
  task = task_stablepy[task]
577
 
 
595
  params_ip_mode.append(modeip)
596
  params_ip_scale.append(scaleip)
597
 
 
 
 
 
 
 
 
 
 
 
 
 
598
  self.model.stream_config(concurrency=5, latent_resize_by=1, vae_decoding=False)
599
 
600
  if task != "txt2img" and not image_control:
 
618
 
619
  logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  adetailer_params_A = {
622
+ "face_detector_ad": face_detector_ad_a,
623
+ "person_detector_ad": person_detector_ad_a,
624
+ "hand_detector_ad": hand_detector_ad_a,
625
  "prompt": prompt_ad_a,
626
+ "negative_prompt": negative_prompt_ad_a,
627
+ "strength": strength_ad_a,
628
  # "image_list_task" : None,
629
+ "mask_dilation": mask_dilation_a,
630
+ "mask_blur": mask_blur_a,
631
+ "mask_padding": mask_padding_a,
632
+ "inpaint_only": adetailer_inpaint_only,
633
+ "sampler": adetailer_sampler,
634
  }
635
 
636
  adetailer_params_B = {
637
+ "face_detector_ad": face_detector_ad_b,
638
+ "person_detector_ad": person_detector_ad_b,
639
+ "hand_detector_ad": hand_detector_ad_b,
640
  "prompt": prompt_ad_b,
641
+ "negative_prompt": negative_prompt_ad_b,
642
+ "strength": strength_ad_b,
643
  # "image_list_task" : None,
644
+ "mask_dilation": mask_dilation_b,
645
+ "mask_blur": mask_blur_b,
646
+ "mask_padding": mask_padding_b,
647
  }
648
  pipe_params = {
649
  "prompt": prompt,
 
727
  if hasattr(self.model.pipe, "transformer") and loras_list != ["None"] * 5:
728
  self.model.pipe.transformer.to(self.model.device)
729
  print("transformer to cuda")
730
+
731
  info_state = "PROCESSING "
732
  for img, seed, image_path, metadata in self.model(**pipe_params):
733
  info_state += ">"
 
754
 
755
  sd_gen = GuiSD()
756
 
757
+ CSS = """
758
  .contain { display: flex; flex-direction: column; }
759
  #component-0 { height: 100%; }
760
  #gallery { flex-grow: 1; }
761
  """
762
+ SDXL_TASK = [k for k, v in task_stablepy.items() if v in SDXL_TASKS ]
763
+ SD_TASK = [k for k, v in task_stablepy.items() if v in SD15_TASKS ]
764
+ FLUX_TASK = list(task_stablepy.keys())[:3] + [k for k, v in task_stablepy.items() if v in FLUX_CN_UNION_MODES.keys() ]
765
+
766
+ MODEL_TYPE_TASK = {
767
+ "SD 1.5": SD_TASK,
768
+ "SDXL": SDXL_TASK,
769
+ "FLUX": FLUX_TASK,
770
+ }
771
+
772
+ MODEL_TYPE_CLASS = {
773
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
774
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
775
+ "diffusers:FluxPipeline": "FLUX",
776
+ }
777
+
778
+
779
  def update_task_options(model_name, task_name):
780
+ new_choices = MODEL_TYPE_TASK[get_model_type(model_name)]
 
 
 
 
781
 
782
+ if task_name not in new_choices:
783
+ task_name = "txt2img"
784
+
785
+ return gr.update(value=task_name, choices=new_choices)
786
 
 
 
 
787
 
788
  POST_PROCESSING_SAMPLER = ["Use same sampler"] + scheduler_names[:-2]
789
 
790
+ SUBTITLE_GUI = (
791
+ "### This demo uses [diffusers](https://github.com/huggingface/diffusers)"
792
+ " to perform different tasks in image generation."
793
+ )
794
+
795
  with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
796
  gr.Markdown("# 🧩 DiffuseCraft")
797
+ gr.Markdown(SUBTITLE_GUI)
 
 
 
 
798
  with gr.Tab("Generation"):
799
  with gr.Row():
800
 
801
  with gr.Column(scale=2):
802
 
803
+ task_gui = gr.Dropdown(label="Task", choices=SDXL_TASK, value=TASK_MODEL_LIST[0])
804
  model_name_gui = gr.Dropdown(label="Model", choices=model_list, value=model_list[0], allow_custom_value=True)
805
  prompt_gui = gr.Textbox(lines=5, placeholder="Enter prompt", label="Prompt")
806
  neg_prompt_gui = gr.Textbox(lines=3, placeholder="Enter Neg prompt", label="Negative prompt")
 
1119
  gr.Markdown(
1120
  """### The following examples perform specific tasks:
1121
  1. Generation with SDXL and upscale
1122
+ 2. Generation with FLUX dev
1123
  3. ControlNet Canny SDXL
1124
  4. Optical pattern (Optical illusion) SDXL
1125
  5. Convert an image to a coloring drawing
 
1174
  "Nearest",
1175
  ],
1176
  [
1177
+ "a tiny astronaut hatching from an egg on the moon",
1178
+ "",
1179
  1,
1180
+ 28,
1181
+ 3.5,
1182
  True,
1183
  -1,
1184
  "None",
 
1186
  "None",
1187
  1.0,
1188
  "None",
1189
+ 1.0,
1190
  "None",
1191
  1.0,
1192
  "None",
1193
  1.0,
1194
+ "Euler a",
1195
  1024,
1196
  1024,
1197
+ "black-forest-labs/FLUX.1-dev",
1198
  None, # vae
1199
  "txt2img",
1200
  None, # img control
 
1214
  1., # cn end
1215
  False, # ti
1216
  "Classic",
1217
+ None,
1218
  ],
1219
  [
1220
  "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff ",
 
1627
  show_error=True,
1628
  debug=True,
1629
  allowed_paths=["./images/"],
1630
+ )