jiuface committed on
Commit cdeb4dc
1 Parent(s): 3b7f155

Update app.py

Files changed (1)
  1. app.py +39 -45
app.py CHANGED
@@ -57,55 +57,14 @@ class calculateDuration:
 
 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress):
-
-    lora_configs = None
-    adapter_names = []
-    if lora_strings_json:
-        try:
-            lora_configs = json.loads(lora_strings_json)
-        except:
-            gr.Warning("Parse lora config json failed")
-            print("parse lora config json failed")
-
-    if lora_configs:
-        with calculateDuration("Loading LoRA weights"):
-            adapter_weights = []
-            for lora_info in lora_configs:
-                lora_repo = lora_info.get("repo")
-                weights = lora_info.get("weights")
-                adapter_name = lora_info.get("adapter_name")
-                adapter_weight = lora_info.get("adapter_weight")
-
-                if lora_repo and weights and adapter_name:
-                    retry_count = 3
-                    for attempt in range(retry_count):
-                        try:
-                            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
-                            adapter_names.append(adapter_name)
-                            adapter_weights.append(adapter_weight)
-                            break  # Load successful, exit retry loop
-                        except ValueError as e:
-                            print(f"Attempt {attempt+1}/{retry_count} failed to load LoRA adapter: {e}")
-                            if attempt == retry_count - 1:
-                                print(f"Error loading LoRA adapter: {adapter_name} after {retry_count} attempts")
-                            else:
-                                time.sleep(1)  # Wait before retrying
-
-            # set lora weights
-            if len(adapter_names) > 0:
-                pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+def generate_image(prompt, adapter_names, steps, seed, cfg_scale, width, height, progress):
 
 
     gr.Info("Start to generate images ...")
-
     with calculateDuration(f"Make a new generator:{seed}"):
         pipe.to(device)
         generator = torch.Generator(device=device).manual_seed(seed)
 
-    if len(adapter_names) > 0:
-        pipe.fuse_lora(adapter_names=adapter_names)
-
     with calculateDuration("Generating image"):
         # Generate image
         generated_image = pipe(
@@ -119,8 +78,6 @@ def generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, he
     ).images[0]
 
     progress(99, "Generate image success!")
-    if len(adapter_names) > 0:
-        pipe.unfuse_lora()
     return generated_image
 
 
@@ -173,12 +130,49 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
 
     # Load LoRA weights
     gr.Info("Start to load loras ...")
+    lora_configs = None
+    adapter_names = []
+    if lora_strings_json:
+        try:
+            lora_configs = json.loads(lora_strings_json)
+        except:
+            gr.Warning("Parse lora config json failed")
+            print("parse lora config json failed")
+
+    if lora_configs:
+        with calculateDuration("Loading LoRA weights"):
+            adapter_weights = []
+            for lora_info in lora_configs:
+                lora_repo = lora_info.get("repo")
+                weights = lora_info.get("weights")
+                adapter_name = lora_info.get("adapter_name")
+                adapter_weight = lora_info.get("adapter_weight")
+
+                if lora_repo and weights and adapter_name:
+                    retry_count = 3
+                    for attempt in range(retry_count):
+                        try:
+                            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                            adapter_names.append(adapter_name)
+                            adapter_weights.append(adapter_weight)
+                            break  # Load successful, exit retry loop
+                        except ValueError as e:
+                            print(f"Attempt {attempt+1}/{retry_count} failed to load LoRA adapter: {e}")
+                            if attempt == retry_count - 1:
+                                print(f"Error loading LoRA adapter: {adapter_name} after {retry_count} attempts")
+                            else:
+                                time.sleep(1)  # Wait before retrying
+
+            # set lora weights
+            if len(adapter_names) > 0:
+                pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+
 
     # Generate image
     error_message = ""
     try:
         print("Start applying for zeroGPU resources")
-        final_image = generate_image(prompt, lora_strings_json, steps, seed, cfg_scale, width, height, progress)
+        final_image = generate_image(prompt, adapter_names, steps, seed, cfg_scale, width, height, progress)
     except Exception as e:
         error_message = str(e)
         gr.Error(error_message)