multimodalart HF staff commited on
Commit
e2bdec1
1 Parent(s): bade8d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -18
app.py CHANGED
@@ -207,6 +207,24 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
207
  del lora_model
208
  gc.collect()
209
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
211
  global last_lora, last_merged, last_fused, pipe
212
 
@@ -254,7 +272,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
254
  weight_name = sdxl_loras[selected_state.index]["weights"]
255
 
256
  full_path_lora = state_dicts[repo_name]["saved_name"]
257
- #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
258
  cross_attention_kwargs = None
259
  print("Last LoRA: ", last_lora)
260
  print("Current LoRA: ", repo_name)
@@ -263,7 +281,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
263
  if(last_fused):
264
  pipe.unfuse_lora()
265
  pipe.unload_lora_weights()
266
- pipe.load_lora_weights(full_path_lora)
267
  pipe.fuse_lora(lora_scale)
268
  last_fused = True
269
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
@@ -287,22 +305,8 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
287
  negative_conditioning, negative_pooled = compel(negative)
288
  else:
289
  negative_conditioning, negative_pooled = None, None
290
-
291
- image = pipe(
292
- prompt_embeds=conditioning,
293
- pooled_prompt_embeds=pooled,
294
- negative_prompt_embeds=negative_conditioning,
295
- negative_pooled_prompt_embeds=negative_pooled,
296
- width=1024,
297
- height=1024,
298
- image_embeds=face_emb,
299
- image=face_image,
300
- strength=1-image_strength,
301
- control_image=images,
302
- num_inference_steps=20,
303
- guidance_scale = guidance_scale,
304
- controlnet_conditioning_scale=[face_strength, depth_control_scale],
305
- ).images[0]
306
  last_lora = repo_name
307
  gc.collect()
308
  return image, gr.update(visible=True)
 
207
  del lora_model
208
  gc.collect()
209
  @spaces.GPU
210
+ def generate_image(conditioning, pooled, negative_conditioning, negative_pooled, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale):
211
+ image = pipe(
212
+ prompt_embeds=conditioning,
213
+ pooled_prompt_embeds=pooled,
214
+ negative_prompt_embeds=negative_conditioning,
215
+ negative_pooled_prompt_embeds=negative_pooled,
216
+ width=1024,
217
+ height=1024,
218
+ image_embeds=face_emb,
219
+ image=face_image,
220
+ strength=1-image_strength,
221
+ control_image=images,
222
+ num_inference_steps=20,
223
+ guidance_scale = guidance_scale,
224
+ controlnet_conditioning_scale=[face_strength, depth_control_scale],
225
+ ).images[0]
226
+ return image
227
+
228
  def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
229
  global last_lora, last_merged, last_fused, pipe
230
 
 
272
  weight_name = sdxl_loras[selected_state.index]["weights"]
273
 
274
  full_path_lora = state_dicts[repo_name]["saved_name"]
275
+ loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
276
  cross_attention_kwargs = None
277
  print("Last LoRA: ", last_lora)
278
  print("Current LoRA: ", repo_name)
 
281
  if(last_fused):
282
  pipe.unfuse_lora()
283
  pipe.unload_lora_weights()
284
+ pipe.load_lora_weights(loaded_state_dict)
285
  pipe.fuse_lora(lora_scale)
286
  last_fused = True
287
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
 
305
  negative_conditioning, negative_pooled = compel(negative)
306
  else:
307
  negative_conditioning, negative_pooled = None, None
308
+
309
+ image = generate_image(conditioning, pooled, negative_conditioning, negative_pooled, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  last_lora = repo_name
311
  gc.collect()
312
  return image, gr.update(visible=True)