John6666 committed
Commit 1d22096
1 Parent(s): db65c96

Upload 2 files

Files changed (2)
  1. app.py +2 -2
  2. mod.py +4 -1
app.py CHANGED
@@ -221,7 +221,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_width=True, css=css) as app:
         triggers=[generate_button.click, prompt.submit],
         fn=change_base_model,
         inputs=[model_name],
-        outputs=None
+        outputs=[result]
     ).success(
         fn=run_lora,
         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
@@ -229,7 +229,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_width=True, css=css) as app:
         outputs=[result, seed]
     )
 
-    model_name.change(change_base_model, [model_name], None)
+    model_name.change(change_base_model, [model_name], [result])
 
     gr.on(
         triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit],
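Note on the app.py change: Gradio renders an event's progress/loading indicator on that event's output components, so wiring outputs=[result] (instead of None) is what lets the loading status reported by change_base_model appear on the result component before the chained run_lora step runs. Below is a minimal, self-contained sketch of the same pattern; the handlers and component values (load_model, generate, "some/model") are simplified stand-ins for the space's actual change_base_model/run_lora wiring, not code from this repository.

import gradio as gr

def load_model(name: str, progress=gr.Progress(track_tqdm=True)):
    # Explicit progress() calls are shown on the event's output components.
    progress(0, desc=f"Loading model: {name}")
    # ... heavy model loading would happen here ...
    progress(1, desc=f"Model loaded: {name}")
    return gr.update(visible=True)  # keep the output component visible once loading finishes

def generate(prompt: str) -> str:
    return f"Generated with prompt: {prompt!r}"

with gr.Blocks() as demo:
    model_name = gr.Textbox(value="some/model", label="Model")
    prompt = gr.Textbox(label="Prompt")
    generate_button = gr.Button("Generate")
    result = gr.Textbox(label="Result")

    # Load the model first; only on success run generation. Because the first event
    # lists result in its outputs, its progress messages are displayed on that component.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=load_model,
        inputs=[model_name],
        outputs=[result],
    ).success(fn=generate, inputs=[prompt], outputs=[result])

demo.launch()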
mod.py CHANGED
@@ -73,14 +73,17 @@ base_model = models[0]
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
 
 
-def change_base_model(repo_id: str):
+def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
     global pipe
     try:
         if not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
+        progress(0, f"Loading model: {repo_id}")
         clear_cache()
         pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
+        progress(1, f"Model loaded: {repo_id}")
     except Exception as e:
         print(e)
+    return gr.update(visible=True)
 
 
 def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
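Note on the mod.py change: declaring progress=gr.Progress(track_tqdm=True) as a default parameter makes Gradio inject a progress tracker when the function is invoked from an event; calling progress(0, ...) and progress(1, ...) pushes status text to the UI, and track_tqdm=True is meant to also mirror tqdm progress bars raised while the call runs. The new parameter and the gr.update(visible=True) return value require mod.py to have gradio imported as gr (not shown in this hunk). A minimal stand-alone sketch of the progress behaviour, with hypothetical names (slow_task, demo) rather than code from this repository:

import time
import gradio as gr
from tqdm import tqdm

def slow_task(n: int, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Starting")
    for _ in tqdm(range(int(n)), desc="Working"):  # tqdm progress is mirrored in the UI
        time.sleep(0.1)
    progress(1, desc="Done")
    return gr.update(visible=True)

with gr.Blocks() as demo:
    steps = gr.Slider(1, 50, value=10, step=1, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(slow_task, inputs=[steps], outputs=[status])

demo.launch()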