John6666 commited on
Commit
c60bd9d
1 Parent(s): 0953761

Upload 5 files

Browse files
Files changed (3) hide show
  1. app.py +5 -2
  2. mod.py +31 -4
  3. requirements.txt +6 -1
app.py CHANGED
@@ -11,7 +11,8 @@ import random
11
  import time
12
 
13
  from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
14
- description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras, get_trigger_word, pipe)
 
15
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
16
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
17
  update_loras)
@@ -241,6 +242,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
241
  tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use CogFlorence-2.1-Large", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
242
  tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
243
  prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
 
244
  with gr.Column(scale=1, elem_id="gen_column"):
245
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
246
  with gr.Row():
@@ -306,8 +308,8 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
306
  with gr.Accordion("From URL", open=True, visible=True):
307
  with gr.Row():
308
  lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
309
- lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
310
  lora_search_civitai_submit = gr.Button("Search on Civitai")
 
311
  lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
312
  lora_search_civitai_json = gr.JSON(value={}, visible=False)
313
  lora_search_civitai_desc = gr.Markdown(value="", visible=False)
@@ -344,6 +346,7 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
344
  )
345
 
346
  model_name.change(change_base_model, [model_name], [result])
 
347
 
348
  gr.on(
349
  triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
 
11
  import time
12
 
13
  from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
14
+ description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
15
+ get_trigger_word, pipe, enhance_prompt)
16
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
17
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
18
  update_loras)
 
242
  tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use CogFlorence-2.1-Large", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
243
  tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
244
  prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt")
245
+ prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
246
  with gr.Column(scale=1, elem_id="gen_column"):
247
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
248
  with gr.Row():
 
308
  with gr.Accordion("From URL", open=True, visible=True):
309
  with gr.Row():
310
  lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
 
311
  lora_search_civitai_submit = gr.Button("Search on Civitai")
312
+ lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
313
  lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
314
  lora_search_civitai_json = gr.JSON(value={}, visible=False)
315
  lora_search_civitai_desc = gr.Markdown(value="", visible=False)
 
346
  )
347
 
348
  model_name.change(change_base_model, [model_name], [result])
349
+ prompt_enhance.click(enhance_prompt, [prompt], [prompt])
350
 
351
  gr.on(
352
  triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
mod.py CHANGED
@@ -7,6 +7,7 @@ import gc
7
  import subprocess
8
 
9
 
 
10
  subprocess.run('pip cache purge', shell=True)
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  torch.set_grad_enabled(False)
@@ -61,7 +62,7 @@ def get_repo_safetensors(repo_id: str):
61
  if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
62
  files = api.list_repo_files(repo_id=repo_id)
63
  except Exception as e:
64
- print(f"Error: Failed to get {repo_id}'s info. ")
65
  print(e)
66
  return gr.update(choices=[])
67
  files = [f for f in files if f.endswith(".safetensors")]
@@ -138,8 +139,7 @@ def fuse_loras(pipe, lorajson: list[dict]):
138
  #pipe.unload_lora_weights()
139
 
140
 
141
- change_base_model.zerogpu = True
142
- fuse_loras.zerogpu = True
143
 
144
 
145
  def description_ui():
@@ -148,4 +148,31 @@ def description_ui():
148
  - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
149
  [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
150
  """
151
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import subprocess
8
 
9
 
10
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
11
  subprocess.run('pip cache purge', shell=True)
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch.set_grad_enabled(False)
 
62
  if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return gr.update(value="", choices=[])
63
  files = api.list_repo_files(repo_id=repo_id)
64
  except Exception as e:
65
+ print(f"Error: Failed to get {repo_id}'s info.")
66
  print(e)
67
  return gr.update(choices=[])
68
  files = [f for f in files if f.endswith(".safetensors")]
 
139
  #pipe.unload_lora_weights()
140
 
141
 
142
+
 
143
 
144
 
145
  def description_ui():
 
148
  - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
149
  [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
150
  """
151
+ )
152
+
153
+
154
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
155
+ def load_prompt_enhancer():
156
+ try:
157
+ model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
158
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
159
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
160
+ enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
161
+ except Exception as e:
162
+ print(e)
163
+ enhancer_flux = None
164
+ return enhancer_flux
165
+
166
+
167
+ enhancer_flux = load_prompt_enhancer()
168
+
169
+
170
+ def enhance_prompt(input_prompt):
171
+ result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
172
+ enhanced_text = result[0]['generated_text']
173
+ return enhanced_text
174
+
175
+
176
+ load_prompt_enhancer.zerogpu = True
177
+ change_base_model.zerogpu = True
178
+ fuse_loras.zerogpu = True
requirements.txt CHANGED
@@ -1,7 +1,12 @@
1
  torch
 
 
 
2
  git+https://github.com/huggingface/diffusers
3
  spaces
4
  transformers
5
  peft
6
  sentencepiece
7
- timm
 
 
 
1
  torch
2
+ torchvision
3
+ huggingface_hub
4
+ accelerate
5
  git+https://github.com/huggingface/diffusers
6
  spaces
7
  transformers
8
  peft
9
  sentencepiece
10
+ timm
11
+ xformers
12
+ einops