John6666 commited on
Commit
5e404f6
1 Parent(s): 9ceef7d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +26 -10
  2. mod.py +29 -23
  3. requirements.txt +3 -4
app.py CHANGED
@@ -11,6 +11,7 @@ import copy
11
  import random
12
  import time
13
 
 
14
  from mod import (models, clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists,
15
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
16
  get_trigger_word, enhance_prompt, deselect_lora, num_cns, set_control_union_image,
@@ -21,11 +22,14 @@ from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_
21
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
22
  from tagger.fl2flux import predict_tags_fl2_flux
23
 
 
 
 
24
  # Initialize the base model
25
  base_model = models[0]
26
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
27
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
28
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
29
  controlnet_union = None
30
  controlnet = None
31
  last_model = models[0]
@@ -33,14 +37,13 @@ last_cn_on = False
33
 
34
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
35
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
36
- def change_base_model(repo_id: str, cn_on: bool):
37
  global pipe
38
  global controlnet_union
39
  global controlnet
40
  global last_model
41
  global last_cn_on
42
- dtype = torch.bfloat16
43
- #dtype = torch.float8_e4m3fn
44
  try:
45
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
46
  if cn_on:
@@ -65,7 +68,7 @@ def change_base_model(repo_id: str, cn_on: bool):
65
  print(f"Model loaded: {repo_id}")
66
  except Exception as e:
67
  print(f"Model load Error: {e}")
68
- raise gr.Error(f"Model load Error: {e}")
69
  return gr.update(visible=True)
70
 
71
  change_base_model.zerogpu = True
@@ -156,7 +159,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
156
  ).images[0]
157
  except Exception as e:
158
  print(e)
159
- raise gr.Error(f"Inference Error: {e}")
160
  return image
161
 
162
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
@@ -210,7 +213,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
210
  if controlnet is not None: controlnet.to("cpu")
211
  if controlnet_union is not None: controlnet_union.to("cpu")
212
  clear_cache()
213
- return image, seed
214
 
215
  def get_huggingface_safetensors(link):
216
  split_link = link.split("/")
@@ -299,6 +302,10 @@ css = '''
299
  .card_internal{display: flex;height: 100px;margin-top: .5em}
300
  .card_internal img{margin-right: 1em}
301
  .styler{--form-gap-width: 0px !important}
 
 
 
 
302
  #model-info {text-align: center; !important}
303
  '''
304
  with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache=(60, 3600)) as app:
@@ -444,7 +451,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
444
  fn=change_base_model,
445
  inputs=[model_name, cn_on],
446
  outputs=[result],
447
- queue=False,
448
  show_api=False,
449
  trigger_mode="once",
450
  ).success(
@@ -457,7 +464,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
457
  )
458
 
459
  deselect_lora_button.click(deselect_lora, None, [prompt, selected_info, selected_index, width, height], queue=False, show_api=False)
460
- gr.on(
461
  triggers=[model_name.change, cn_on.change],
462
  fn=change_base_model,
463
  inputs=[model_name, cn_on],
@@ -465,7 +472,16 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache
465
  queue=True,
466
  show_api=False,
467
  trigger_mode="once",
468
- ).then(get_t2i_model_info, [model_name], [model_info], queue=False, show_api=False)
 
 
 
 
 
 
 
 
 
469
  prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
470
 
471
  gr.on(
 
11
  import random
12
  import time
13
 
14
+
15
  from mod import (models, clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists,
16
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
17
  get_trigger_word, enhance_prompt, deselect_lora, num_cns, set_control_union_image,
 
22
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
23
  from tagger.fl2flux import predict_tags_fl2_flux
24
 
25
+
26
+ dtype = torch.bfloat16
27
+ #dtype = torch.float8_e4m3fn
28
  # Initialize the base model
29
  base_model = models[0]
30
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
31
  #controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
32
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype)
33
  controlnet_union = None
34
  controlnet = None
35
  last_model = models[0]
 
37
 
38
  # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
39
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
40
+ def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
41
  global pipe
42
  global controlnet_union
43
  global controlnet
44
  global last_model
45
  global last_cn_on
46
+ global dtype
 
47
  try:
48
  if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(visible=True)
49
  if cn_on:
 
68
  print(f"Model loaded: {repo_id}")
69
  except Exception as e:
70
  print(f"Model load Error: {e}")
71
+ raise gr.Error(f"Model load Error: {e}") from e
72
  return gr.update(visible=True)
73
 
74
  change_base_model.zerogpu = True
 
159
  ).images[0]
160
  except Exception as e:
161
  print(e)
162
+ raise gr.Error(f"Inference Error {e}") from e
163
  return image
164
 
165
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
 
213
  if controlnet is not None: controlnet.to("cpu")
214
  if controlnet_union is not None: controlnet_union.to("cpu")
215
  clear_cache()
216
+ return image, seed
217
 
218
  def get_huggingface_safetensors(link):
219
  split_link = link.split("/")
 
302
  .card_internal{display: flex;height: 100px;margin-top: .5em}
303
  .card_internal img{margin-right: 1em}
304
  .styler{--form-gap-width: 0px !important}
305
+ #progress{height:30px}
306
+ #progress .generating{display:none}
307
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
308
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
309
  #model-info {text-align: center; !important}
310
  '''
311
  with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache=(60, 3600)) as app:
 
451
  fn=change_base_model,
452
  inputs=[model_name, cn_on],
453
  outputs=[result],
454
+ queue=True,
455
  show_api=False,
456
  trigger_mode="once",
457
  ).success(
 
464
  )
465
 
466
  deselect_lora_button.click(deselect_lora, None, [prompt, selected_info, selected_index, width, height], queue=False, show_api=False)
467
+ """gr.on(
468
  triggers=[model_name.change, cn_on.change],
469
  fn=change_base_model,
470
  inputs=[model_name, cn_on],
 
472
  queue=True,
473
  show_api=False,
474
  trigger_mode="once",
475
+ ).then(get_t2i_model_info, [model_name], [model_info], queue=True, show_api=False)"""
476
+ gr.on(
477
+ triggers=[model_name.change, cn_on.change],
478
+ fn=get_t2i_model_info,
479
+ inputs=[model_name],
480
+ outputs=[model_info],
481
+ queue=False,
482
+ show_api=False,
483
+ trigger_mode="once",
484
+ ).then(change_base_model, [model_name, cn_on], [result], queue=True, show_api=False)
485
  prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
486
 
487
  gr.on(
mod.py CHANGED
@@ -69,7 +69,7 @@ def is_repo_exists(repo_id):
69
  if api.repo_exists(repo_id=repo_id): return True
70
  else: return False
71
  except Exception as e:
72
- print(f"Error: Failed to connect {repo_id}. ")
73
  print(e)
74
  return True # for safe
75
 
@@ -82,6 +82,7 @@ def clear_cache():
82
  gc.collect()
83
  except Exception as e:
84
  print(e)
 
85
 
86
 
87
  def deselect_lora():
@@ -108,6 +109,7 @@ def get_repo_safetensors(repo_id: str):
108
  except Exception as e:
109
  print(f"Error: Failed to get {repo_id}'s info.")
110
  print(e)
 
111
  return gr.update(choices=[])
112
  files = [f for f in files if f.endswith(".safetensors")]
113
  if len(files) == 0: return gr.update(value="", choices=[])
@@ -290,28 +292,32 @@ def get_trigger_word(lorajson: list[dict]):
290
  # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
291
  # https://github.com/huggingface/diffusers/issues/4919
292
  def fuse_loras(pipe, lorajson: list[dict]):
293
- if not lorajson or not isinstance(lorajson, list): return
294
- a_list = []
295
- w_list = []
296
- for d in lorajson:
297
- if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None": continue
298
- k = d["name"]
299
- if is_repo_name(k) and is_repo_exists(k):
300
- a_name = Path(k).stem
301
- pipe.load_lora_weights(k, weight_name=d["filename"], adapter_name = a_name)
302
- elif not Path(k).exists():
303
- print(f"LoRA not found: {k}")
304
- continue
305
- else:
306
- w_name = Path(k).name
307
- a_name = Path(k).stem
308
- pipe.load_lora_weights(k, weight_name = w_name, adapter_name = a_name)
309
- a_list.append(a_name)
310
- w_list.append(d["scale"])
311
- if not a_list: return
312
- pipe.set_adapters(a_list, adapter_weights=w_list)
313
- pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
314
- #pipe.unload_lora_weights()
 
 
 
 
315
 
316
 
317
  def description_ui():
 
69
  if api.repo_exists(repo_id=repo_id): return True
70
  else: return False
71
  except Exception as e:
72
+ print(f"Error: Failed to connect {repo_id}.")
73
  print(e)
74
  return True # for safe
75
 
 
82
  gc.collect()
83
  except Exception as e:
84
  print(e)
85
+ raise Exception(f"Cache clearing error: {e}") from e
86
 
87
 
88
  def deselect_lora():
 
109
  except Exception as e:
110
  print(f"Error: Failed to get {repo_id}'s info.")
111
  print(e)
112
+ gr.Warning(f"Error: Failed to get {repo_id}'s info.")
113
  return gr.update(choices=[])
114
  files = [f for f in files if f.endswith(".safetensors")]
115
  if len(files) == 0: return gr.update(value="", choices=[])
 
292
  # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
293
  # https://github.com/huggingface/diffusers/issues/4919
294
  def fuse_loras(pipe, lorajson: list[dict]):
295
+ try:
296
+ if not lorajson or not isinstance(lorajson, list): return
297
+ a_list = []
298
+ w_list = []
299
+ for d in lorajson:
300
+ if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None": continue
301
+ k = d["name"]
302
+ if is_repo_name(k) and is_repo_exists(k):
303
+ a_name = Path(k).stem
304
+ pipe.load_lora_weights(k, weight_name=d["filename"], adapter_name = a_name)
305
+ elif not Path(k).exists():
306
+ print(f"LoRA not found: {k}")
307
+ continue
308
+ else:
309
+ w_name = Path(k).name
310
+ a_name = Path(k).stem
311
+ pipe.load_lora_weights(k, weight_name = w_name, adapter_name = a_name)
312
+ a_list.append(a_name)
313
+ w_list.append(d["scale"])
314
+ if not a_list: return
315
+ pipe.set_adapters(a_list, adapter_weights=w_list)
316
+ pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
317
+ #pipe.unload_lora_weights()
318
+ except Exception as e:
319
+ print(f"External LoRA Error: {e}")
320
+ raise Exception(f"External LoRA Error: {e}") from e
321
 
322
 
323
  def description_ui():
requirements.txt CHANGED
@@ -1,12 +1,11 @@
1
  spaces
 
2
  git+https://github.com/huggingface/diffusers
3
- torch==2.2.0
4
- torchvision
5
- huggingface_hub
6
- accelerate
7
  transformers
8
  peft
9
  sentencepiece
 
 
10
  timm
11
  einops
12
  controlnet_aux
 
1
  spaces
2
+ torch
3
  git+https://github.com/huggingface/diffusers
 
 
 
 
4
  transformers
5
  peft
6
  sentencepiece
7
+ torchvision
8
+ huggingface_hub
9
  timm
10
  einops
11
  controlnet_aux