John6666 commited on
Commit
f0c3651
1 Parent(s): 5764354

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +41 -29
  2. mod.py +17 -1
app.py CHANGED
@@ -5,8 +5,9 @@ import torch
5
  from PIL import Image
6
  import spaces
7
  from diffusers import DiffusionPipeline
8
- from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
9
- from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 
10
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
11
  import copy
12
  import random
@@ -14,7 +15,7 @@ import time
14
 
15
  from mod import (models, clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists,
16
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
17
- get_trigger_word, enhance_prompt, num_cns, set_control_union_image,
18
  get_control_union_mode, set_control_union_mode, get_control_params)
19
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
20
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
@@ -26,6 +27,8 @@ from tagger.fl2flux import predict_tags_fl2_flux
26
  base_model = models[0]
27
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
28
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
 
29
  last_model = models[0]
30
  last_cn_on = False
31
 
@@ -33,6 +36,8 @@ last_cn_on = False
33
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
34
  def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
35
  global pipe
 
 
36
  global last_model
37
  global last_cn_on
38
  try:
@@ -115,31 +120,35 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
115
  with calculateDuration("Generating image"):
116
  # Generate image
117
  modes, images, scales = get_control_params()
118
- if not cn_on or len(modes) == 0:
119
- progress(0, desc="Start Inference.")
120
- image = pipe(
121
- prompt=prompt_mash,
122
- num_inference_steps=steps,
123
- guidance_scale=cfg_scale,
124
- width=width,
125
- height=height,
126
- generator=generator,
127
- joint_attention_kwargs={"scale": lora_scale},
128
- ).images[0]
129
- else:
130
- progress(0, desc="Start Inference with ControlNet.")
131
- image = pipe(
132
- prompt=prompt_mash,
133
- control_image=images,
134
- control_mode=modes,
135
- num_inference_steps=steps,
136
- guidance_scale=cfg_scale,
137
- width=width,
138
- height=height,
139
- controlnet_conditioning_scale=scales,
140
- generator=generator,
141
- joint_attention_kwargs={"scale": lora_scale},
142
- ).images[0]
 
 
 
 
143
  return image
144
 
145
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
@@ -320,6 +329,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
320
  gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
321
  custom_lora_info = gr.HTML(visible=False)
322
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
 
323
  with gr.Column(scale=4):
324
  result = gr.Image(label="Generated Image", format="png", show_share_button=False)
325
 
@@ -422,6 +432,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
422
  outputs=[result, seed]
423
  )
424
 
 
425
  gr.on(
426
  triggers=[model_name.change, cn_on.change],
427
  fn=change_base_model,
@@ -443,6 +454,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
443
  lora_search_civitai_result.change(select_civitai_lora, [lora_search_civitai_result], [lora_download_url, lora_search_civitai_desc], scroll_to_output=True, queue=False, show_api=False)
444
 
445
  for i, l in enumerate(lora_repo):
 
446
  gr.on(
447
  triggers=[lora_download[i].click],
448
  fn=download_my_lora,
@@ -463,7 +475,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
463
  ).success(get_repo_safetensors, [lora_repo[i]], [lora_weights[i]], queue=False, show_api=False
464
  ).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
465
  ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
466
-
467
  for i, m in enumerate(cn_mode):
468
  gr.on(
469
  triggers=[cn_mode[i].change, cn_scale[i].change],
 
5
  from PIL import Image
6
  import spaces
7
  from diffusers import DiffusionPipeline
8
+ #from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
9
+ #from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
10
+ from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel
11
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
  import copy
13
  import random
 
15
 
16
  from mod import (models, clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists,
17
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
18
+ get_trigger_word, enhance_prompt, deselect_lora, num_cns, set_control_union_image,
19
  get_control_union_mode, set_control_union_mode, get_control_params)
20
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
21
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
 
27
  base_model = models[0]
28
  controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
29
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
30
+ controlnet_union = None
31
+ controlnet = None
32
  last_model = models[0]
33
  last_cn_on = False
34
 
 
36
  # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
37
  def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
38
  global pipe
39
+ global controlnet_union
40
+ global controlnet
41
  global last_model
42
  global last_cn_on
43
  try:
 
120
  with calculateDuration("Generating image"):
121
  # Generate image
122
  modes, images, scales = get_control_params()
123
+ try:
124
+ if not cn_on or len(modes) == 0:
125
+ progress(0, desc="Start Inference.")
126
+ image = pipe(
127
+ prompt=prompt_mash,
128
+ num_inference_steps=steps,
129
+ guidance_scale=cfg_scale,
130
+ width=width,
131
+ height=height,
132
+ generator=generator,
133
+ joint_attention_kwargs={"scale": lora_scale},
134
+ ).images[0]
135
+ else:
136
+ progress(0, desc="Start Inference with ControlNet.")
137
+ image = pipe(
138
+ prompt=prompt_mash,
139
+ control_image=images,
140
+ control_mode=modes,
141
+ num_inference_steps=steps,
142
+ guidance_scale=cfg_scale,
143
+ width=width,
144
+ height=height,
145
+ controlnet_conditioning_scale=scales,
146
+ generator=generator,
147
+ joint_attention_kwargs={"scale": lora_scale},
148
+ ).images[0]
149
+ except Exception as e:
150
+ print(e)
151
+ raise Exception(f"Inference Error: {e}")
152
  return image
153
 
154
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
 
329
  gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
330
  custom_lora_info = gr.HTML(visible=False)
331
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
332
+ deselect_lora_button = gr.Button("Deselect LoRA", variant="secondary")
333
  with gr.Column(scale=4):
334
  result = gr.Image(label="Generated Image", format="png", show_share_button=False)
335
 
 
432
  outputs=[result, seed]
433
  )
434
 
435
+ deselect_lora_button.click(deselect_lora, None, [prompt, selected_info, selected_index, width, height])
436
  gr.on(
437
  triggers=[model_name.change, cn_on.change],
438
  fn=change_base_model,
 
454
  lora_search_civitai_result.change(select_civitai_lora, [lora_search_civitai_result], [lora_download_url, lora_search_civitai_desc], scroll_to_output=True, queue=False, show_api=False)
455
 
456
  for i, l in enumerate(lora_repo):
457
+ deselect_lora_button.click(lambda: ("", 1.0), None, [lora_repo[i], lora_wt[i]])
458
  gr.on(
459
  triggers=[lora_download[i].click],
460
  fn=download_my_lora,
 
475
  ).success(get_repo_safetensors, [lora_repo[i]], [lora_weights[i]], queue=False, show_api=False
476
  ).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
477
  ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
478
+
479
  for i, m in enumerate(cn_mode):
480
  gr.on(
481
  triggers=[cn_mode[i].change, cn_scale[i].change],
mod.py CHANGED
@@ -62,6 +62,21 @@ def clear_cache():
62
  gc.collect()
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def get_repo_safetensors(repo_id: str):
66
  from huggingface_hub import HfApi
67
  api = HfApi()
@@ -311,4 +326,5 @@ def enhance_prompt(input_prompt):
311
 
312
  load_prompt_enhancer.zerogpu = True
313
  fuse_loras.zerogpu = True
314
-
 
 
62
  gc.collect()
63
 
64
 
65
+ def deselect_lora():
66
+ selected_index = gr.State(None)
67
+ new_placeholder = "Type a prompt"
68
+ updated_text = ""
69
+ width = 1024
70
+ height = 1024
71
+ return (
72
+ gr.update(placeholder=new_placeholder),
73
+ updated_text,
74
+ selected_index,
75
+ width,
76
+ height,
77
+ )
78
+
79
+
80
  def get_repo_safetensors(repo_id: str):
81
  from huggingface_hub import HfApi
82
  api = HfApi()
 
326
 
327
  load_prompt_enhancer.zerogpu = True
328
  fuse_loras.zerogpu = True
329
+ preprocess_image.zerogpu = True
330
+ get_control_params.zerogpu = True