jiuface commited on
Commit
b8cbb2a
1 Parent(s): 4d73de3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -30
app.py CHANGED
@@ -11,8 +11,6 @@ from huggingface_hub import login
11
  import time
12
  from datetime import datetime
13
  from io import BytesIO
14
- # from diffusers.models.attention_processor import AttentionProcessor
15
- from diffusers.models.attention_processor import AttnProcessor2_0
16
  import torch.nn.functional as F
17
  import time
18
  import boto3
@@ -102,41 +100,42 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
102
 
103
  def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
104
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
105
-
106
  # Load LoRA weights
 
107
  if lora_strings_json:
108
  try:
109
  lora_configs = json.loads(lora_strings_json)
110
  except:
111
- lora_configs = None
112
  print("parse lora config json failed")
113
 
114
- if lora_configs:
115
- with calculateDuration("Loading LoRA weights"):
116
- active_adapters = pipe.get_active_adapters()
117
- print("get_active_adapters", active_adapters)
118
- adapter_names = []
119
- adapter_weights = []
120
- for lora_info in lora_configs:
121
- lora_repo = lora_info.get("repo")
122
- weights = lora_info.get("weights")
123
- adapter_name = lora_info.get("adapter_name")
124
- adapter_weight = lora_info.get("adapter_weight")
125
- if adapter_name in active_adapters:
126
- print(f"Adapter '{adapter_name}' is already loaded, skipping.")
127
- continue
128
- if lora_repo and weights and adapter_name:
129
- # load lora
130
- try:
131
- pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
132
- except ValueError as e:
133
- print(f"Error loading LoRA adapter: {e}")
134
  continue
135
- adapter_names.append(adapter_name)
136
- adapter_weights.append(adapter_weight)
137
- # set lora weights
138
- if len(adapter_names) > 0:
139
- pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
 
 
 
 
 
 
140
 
141
  # Set random seed for reproducibility
142
  if randomize_seed:
@@ -150,6 +149,7 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
150
  final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
151
  except Exception as e:
152
  error_message = str(e)
 
153
  print("Run error", e)
154
  final_image = None
155
 
@@ -162,7 +162,8 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
162
  result = {"status": "success", "message": "Image generated but not uploaded"}
163
  else:
164
  result = {"status": "failed", "message": error_message}
165
-
 
166
  progress(100, "Completed!")
167
 
168
  return final_image, seed, json.dumps(result)
 
11
  import time
12
  from datetime import datetime
13
  from io import BytesIO
 
 
14
  import torch.nn.functional as F
15
  import time
16
  import boto3
 
100
 
101
  def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
102
  print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
103
+ gr.Info("Starting process")
104
  # Load LoRA weights
105
+ lora_configs = None
106
  if lora_strings_json:
107
  try:
108
  lora_configs = json.loads(lora_strings_json)
109
  except:
110
+ gr.Warning("Parse lora config json failed")
111
  print("parse lora config json failed")
112
 
113
+ if lora_configs:
114
+ with calculateDuration("Loading LoRA weights"):
115
+ active_adapters = pipe.get_active_adapters()
116
+ print("get_active_adapters", active_adapters)
117
+ adapter_names = []
118
+ adapter_weights = []
119
+ for lora_info in lora_configs:
120
+ lora_repo = lora_info.get("repo")
121
+ weights = lora_info.get("weights")
122
+ adapter_name = lora_info.get("adapter_name")
123
+ adapter_weight = lora_info.get("adapter_weight")
124
+ if adapter_name in active_adapters:
125
+ print(f"Adapter '{adapter_name}' is already loaded, skipping.")
 
 
 
 
 
 
 
126
  continue
127
+ if lora_repo and weights and adapter_name:
128
+ # load lora
129
+ try:
130
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
131
+ except ValueError as e:
132
+ print(f"Error loading LoRA adapter: {e}")
133
+ continue
134
+ adapter_names.append(adapter_name)
135
+ adapter_weights.append(adapter_weight)
136
+ # set lora weights
137
+ if len(adapter_names) > 0:
138
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
139
 
140
  # Set random seed for reproducibility
141
  if randomize_seed:
 
149
  final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
150
  except Exception as e:
151
  error_message = str(e)
152
+ gr.Error(error_message)
153
  print("Run error", e)
154
  final_image = None
155
 
 
162
  result = {"status": "success", "message": "Image generated but not uploaded"}
163
  else:
164
  result = {"status": "failed", "message": error_message}
165
+
166
+ gr.Info("Completed!")
167
  progress(100, "Completed!")
168
 
169
  return final_image, seed, json.dumps(result)