fffiloni commited on
Commit
5c98b7c
1 Parent(s): 86b3a94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -6
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
- from huggingface_hub import login
 
3
  from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
4
  import torch
5
  import copy
@@ -11,8 +12,16 @@ is_shared_ui = True if "fffiloni/sd-xl-lora-fusion" in os.environ['SPACE_ID'] el
11
  hf_token = os.environ.get("HF_TOKEN")
12
  login(token = hf_token)
13
 
 
 
14
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
15
 
 
 
 
 
 
 
16
  @spaces.GPU
17
  def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
18
 
@@ -141,16 +150,20 @@ with gr.Blocks(css=css) as demo:
141
 
142
  with gr.Column():
143
 
144
- lora_2_id = gr.Textbox(
145
  label = "LoRa 2 ID",
146
- placeholder = "username/model_id"
 
147
  )
148
 
149
- lora_2_sfts = gr.Textbox(
150
  label = "Safetensors file",
151
- placeholder = "specific_chosen.safetensors"
 
152
  )
153
 
 
 
154
  # PART 2 • INFERENCE
155
  with gr.Row():
156
 
@@ -160,7 +173,7 @@ with gr.Blocks(css=css) as demo:
160
  placeholde = "e.g: a triggerWordOne portrait in triggerWord2 style"
161
  )
162
 
163
- run_btn = gr.Button("Run")
164
 
165
  output_image = gr.Image(
166
  label = "Output"
@@ -205,6 +218,17 @@ with gr.Blocks(css=css) as demo:
205
  )
206
 
207
  # ACTIONS
 
 
 
 
 
 
 
 
 
 
 
208
  run_btn.click(
209
  fn = infer,
210
  inputs = [
 
1
  import gradio as gr
2
+ from huggingface_hub import login, HfFileSystem
3
+
4
  from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
5
  import torch
6
  import copy
 
12
  hf_token = os.environ.get("HF_TOKEN")
13
  login(token = hf_token)
14
 
15
+ fs = HfFileSystem(token=hf_token)
16
+
17
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
18
 
19
+ def load_sfts(repo_1_id, repo_2_id):
20
+ # List all ".safetensors" files in repos
21
+ sfts_available_files_1 = fs.glob(f"{repo_1_id}/**.safetensors")
22
+ sfts_available_files_2 = fs.glob(f"{repo_2_id}/**.safetensors")
23
+ return gr.update(choices=sfts_available_files_1, value=sfts_available_files_1[0], visible=True), gr.update(choices=sfts_available_files_2, value=sfts_available_files_2[0], visible=True)
24
+
25
  @spaces.GPU
26
  def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
27
 
 
150
 
151
  with gr.Column():
152
 
153
+ lora_2_id = gr.Dropdown(
154
  label = "LoRa 2 ID",
155
+ placeholder = "username/model_id",
156
+ visible=False
157
  )
158
 
159
+ lora_2_sfts = gr.Dropdown(
160
  label = "Safetensors file",
161
+ placeholder = "specific_chosen.safetensors",
162
+ visible=False
163
  )
164
 
165
+ load_models_btn = gr.Button("Load models and .safetensors")
166
+
167
  # PART 2 • INFERENCE
168
  with gr.Row():
169
 
 
173
  placeholde = "e.g: a triggerWordOne portrait in triggerWord2 style"
174
  )
175
 
176
+ run_btn = gr.Button("Run")
177
 
178
  output_image = gr.Image(
179
  label = "Output"
 
218
  )
219
 
220
  # ACTIONS
221
+ load_models_btn.click(
222
+ fn = load_sfts,
223
+ inputs = [
224
+ lora_1_id,
225
+ lora_2_id
226
+ ],
227
+ outputs = [
228
+ lora_1_sfts,
229
+ lora_2_sfts
230
+ ]
231
+ )
232
  run_btn.click(
233
  fn = infer,
234
  inputs = [