multimodalart HF staff commited on
Commit
a1a833d
1 Parent(s): 7c6dd97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -91,26 +91,29 @@ original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-
91
 
92
  @spaces.GPU
93
  def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1):
94
- print("Run this?")
95
  repo_id_1 = shuffled_items[0]['repo']
96
  repo_id_2 = shuffled_items[1]['repo']
97
  print("Loading state dicts...")
98
- state_dict_1 = state_dicts[repo_id_1]["state_dict"]
99
- state_dict_2 = state_dicts[repo_id_2]["state_dict"]
100
- print("Loaded state dicts.")
 
 
101
  #pipe = copy.deepcopy(original_pipe)
102
- # Time for pickle
103
- start_time = time.time()
104
  pipe = pickle.loads(pickle.dumps(original_pipe))
105
- pickle_time = time.time() - start_time
106
  print(f"Pickle time: {pickle_time}")
107
  pipe.to("cuda")
 
108
  print("Loading LoRA weights...")
109
  pipe.load_lora_weights(state_dict_1)
110
  pipe.fuse_lora(lora_1_scale)
111
  pipe.load_lora_weights(state_dict_2)
112
  pipe.fuse_lora(lora_2_scale)
113
- print("Loaded LoRA weights.")
 
114
  if negative_prompt == "":
115
  negative_prompt = None
116
 
 
91
 
92
  @spaces.GPU
93
  def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1):
94
+
95
  repo_id_1 = shuffled_items[0]['repo']
96
  repo_id_2 = shuffled_items[1]['repo']
97
  print("Loading state dicts...")
98
+ start_time = time()
99
+ state_dict_1 = pickle.loads(pickle.dumps(state_dicts[repo_id_1]["state_dict"]))
100
+ state_dict_2 = pickle.loads(pickle.dumps(state_dicts[repo_id_2]["state_dict"]))
101
+ state_dict_time = time() - start_time
102
+ print(f"State Dict time: {state_dict_time}")
103
  #pipe = copy.deepcopy(original_pipe)
104
+ start_time = time()
 
105
  pipe = pickle.loads(pickle.dumps(original_pipe))
106
+ pickle_time = time() - start_time
107
  print(f"Pickle time: {pickle_time}")
108
  pipe.to("cuda")
109
+ start_time = time()
110
  print("Loading LoRA weights...")
111
  pipe.load_lora_weights(state_dict_1)
112
  pipe.fuse_lora(lora_1_scale)
113
  pipe.load_lora_weights(state_dict_2)
114
  pipe.fuse_lora(lora_2_scale)
115
+ lora_time = time() - start_time
116
+ print(f"Loaded LoRAs time: {lora_time}")
117
  if negative_prompt == "":
118
  negative_prompt = None
119