John6666 commited on
Commit
5302530
1 Parent(s): dc1b7f5

Upload 49 files

Browse files
Files changed (2) hide show
  1. app.py +45 -6
  2. mod.py +3 -46
app.py CHANGED
@@ -5,14 +5,16 @@ import torch
5
  from PIL import Image
6
  import spaces
7
  from diffusers import DiffusionPipeline
 
 
8
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
9
  import copy
10
  import random
11
  import time
12
 
13
- from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
14
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
15
- get_trigger_word, enhance_prompt, pipe, controlnet, num_cns, set_control_union_image,
16
  get_control_union_mode, set_control_union_mode, get_control_params)
17
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
18
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
@@ -21,6 +23,44 @@ from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
21
  from tagger.fl2cog import predict_tags_fl2_cog
22
  from tagger.fl2flux import predict_tags_fl2_flux
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Load LoRAs from JSON file
26
  with open('loras.json', 'r') as f:
@@ -65,14 +105,14 @@ def update_selection(evt: gr.SelectData, width, height):
65
  )
66
 
67
  @spaces.GPU(duration=70)
68
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress):
69
  pipe.to("cuda")
70
  generator = torch.Generator(device="cuda").manual_seed(seed)
71
 
72
  with calculateDuration("Generating image"):
73
  # Generate image
74
  modes, images, scales = get_control_params()
75
- if not cn_on or controlnet is None or len(modes) == 0:
76
  progress(0, desc="Start Inference.")
77
  image = pipe(
78
  prompt=prompt_mash,
@@ -85,7 +125,6 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
85
  ).images[0]
86
  else:
87
  progress(0, desc="Start Inference with ControlNet.")
88
- print(modes, scales) #
89
  image = pipe(
90
  prompt=prompt_mash,
91
  control_image=images,
@@ -337,7 +376,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
337
  for i in range(num_loras):
338
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
339
 
340
- with gr.Accordion("ControlNet", open=True):
341
  with gr.Column():
342
  cn_on = gr.Checkbox(False, label="Use ControlNet")
343
  cn_mode = [None] * num_cns
 
5
  from PIL import Image
6
  import spaces
7
  from diffusers import DiffusionPipeline
8
+ from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
9
+ from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
10
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
11
  import copy
12
  import random
13
  import time
14
 
15
+ from mod import (models, clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists,
16
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
17
+ get_trigger_word, enhance_prompt, num_cns, set_control_union_image,
18
  get_control_union_mode, set_control_union_mode, get_control_params)
19
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
20
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
 
23
  from tagger.fl2cog import predict_tags_fl2_cog
24
  from tagger.fl2flux import predict_tags_fl2_flux
25
 
26
+ # Initialize the base model
27
+ base_model = models[0]
28
+ controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
29
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
30
+ last_model = models[0]
31
+ last_cn_on = False
32
+
33
+ # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
34
+ # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
35
+ def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
36
+ global pipe
37
+ global last_model
38
+ global last_cn_on
39
+ try:
40
+ if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
41
+ if cn_on:
42
+ progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
43
+ print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
44
+ clear_cache()
45
+ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=torch.bfloat16)
46
+ controlnet = FluxMultiControlNetModel([controlnet_union])
47
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=torch.bfloat16)
48
+ last_model = repo_id
49
+ progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
50
+ print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
51
+ else:
52
+ progress(0, desc=f"Loading model: {repo_id}")
53
+ print(f"Loading model: {repo_id}")
54
+ clear_cache()
55
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
56
+ last_model = repo_id
57
+ progress(1, desc=f"Model loaded: {repo_id}")
58
+ print(f"Model loaded: {repo_id}")
59
+ except Exception as e:
60
+ print(e)
61
+ return gr.update(visible=True)
62
+
63
+ change_base_model.zerogpu = True
64
 
65
  # Load LoRAs from JSON file
66
  with open('loras.json', 'r') as f:
 
105
  )
106
 
107
  @spaces.GPU(duration=70)
108
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
109
  pipe.to("cuda")
110
  generator = torch.Generator(device="cuda").manual_seed(seed)
111
 
112
  with calculateDuration("Generating image"):
113
  # Generate image
114
  modes, images, scales = get_control_params()
115
+ if not cn_on or len(modes) == 0:
116
  progress(0, desc="Start Inference.")
117
  image = pipe(
118
  prompt=prompt_mash,
 
125
  ).images[0]
126
  else:
127
  progress(0, desc="Start Inference with ControlNet.")
 
128
  image = pipe(
129
  prompt=prompt_mash,
130
  control_image=images,
 
376
  for i in range(num_loras):
377
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
378
 
379
+ with gr.Accordion("ControlNet", open=False):
380
  with gr.Column():
381
  cn_on = gr.Checkbox(False, label="Use ControlNet")
382
  cn_mode = [None] * num_cns
mod.py CHANGED
@@ -1,9 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import spaces
4
- from diffusers import DiffusionPipeline
5
- from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
6
- from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
7
  from pathlib import Path
8
  import gc
9
  import subprocess
@@ -35,16 +33,9 @@ models = [
35
 
36
  num_loras = 3
37
  num_cns = 2
38
- # Initialize the base model
39
- base_model = models[0]
40
- controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
41
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
42
- controlnet = None
43
  control_images = [None] * num_cns
44
  control_modes = [-1] * num_cns
45
  control_scales = [0] * num_cns
46
- last_model = models[0]
47
- last_cn_on = False
48
 
49
 
50
  def is_repo_name(s):
@@ -84,39 +75,6 @@ def get_repo_safetensors(repo_id: str):
84
  else: return gr.update(value=files[0], choices=files)
85
 
86
 
87
- # https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny
88
- # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
89
- # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
90
- def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
91
- global pipe
92
- global controlnet
93
- global last_model
94
- global last_cn_on
95
- try:
96
- if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
97
- if cn_on:
98
- progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
99
- print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
100
- clear_cache()
101
- controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=torch.bfloat16)
102
- controlnet = FluxMultiControlNetModel([controlnet_union])
103
- pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=torch.bfloat16)
104
- last_model = repo_id
105
- progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
106
- print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
107
- else:
108
- progress(0, desc=f"Loading model: {repo_id}")
109
- print(f"Loading model: {repo_id}")
110
- clear_cache()
111
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
112
- last_model = repo_id
113
- progress(1, desc=f"Model loaded: {repo_id}")
114
- print(f"Model loaded: {repo_id}")
115
- except Exception as e:
116
- print(e)
117
- return gr.update(visible=True)
118
-
119
-
120
  def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
121
  width, height = pil_img.size
122
  if width == height:
@@ -183,7 +141,8 @@ def get_control_params():
183
 
184
 
185
  from preprocessor import Preprocessor
186
- def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int, preprocess_resolution: int):
 
187
  if control_mode == "None": return image
188
  image_resolution = max(width, height)
189
  image_before = resize_image(expand2square(image), image_resolution, image_resolution, False)
@@ -239,7 +198,6 @@ def preprocess_image(image: Image.Image, control_mode: str, height: int, width:
239
 
240
  image_after = resize_image(control_image, width, height, False)
241
  print(f"generate control image success: {image_width}x{image_height} => {width}x{height}")
242
-
243
  return image_after
244
 
245
 
@@ -347,5 +305,4 @@ def enhance_prompt(input_prompt):
347
 
348
 
349
  load_prompt_enhancer.zerogpu = True
350
- change_base_model.zerogpu = True
351
  fuse_loras.zerogpu = True
 
1
  import gradio as gr
2
  import torch
3
  import spaces
4
+
 
 
5
  from pathlib import Path
6
  import gc
7
  import subprocess
 
33
 
34
  num_loras = 3
35
  num_cns = 2
 
 
 
 
 
36
  control_images = [None] * num_cns
37
  control_modes = [-1] * num_cns
38
  control_scales = [0] * num_cns
 
 
39
 
40
 
41
  def is_repo_name(s):
 
75
  else: return gr.update(value=files[0], choices=files)
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
79
  width, height = pil_img.size
80
  if width == height:
 
141
 
142
 
143
  from preprocessor import Preprocessor
144
+ def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
145
+ preprocess_resolution: int, progress=gr.Progress(track_tqdm=True)):
146
  if control_mode == "None": return image
147
  image_resolution = max(width, height)
148
  image_before = resize_image(expand2square(image), image_resolution, image_resolution, False)
 
198
 
199
  image_after = resize_image(control_image, width, height, False)
200
  print(f"generate control image success: {image_width}x{image_height} => {width}x{height}")
 
201
  return image_after
202
 
203
 
 
305
 
306
 
307
  load_prompt_enhancer.zerogpu = True
 
308
  fuse_loras.zerogpu = True