Spaces: Running on Zero
pure load lora
app.py CHANGED
@@ -14,9 +14,12 @@ from io import BytesIO
 # from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.attention_processor import AttnProcessor2_0
 import torch.nn.functional as F
-
+import time
+import boto3
+from io import BytesIO
 import re
 import json
+
 # Log in to the Hugging Face Hub
 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
@@ -49,262 +52,16 @@ class calculateDuration:
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
-# Mapping of positions, offsets, and areas
-valid_locations = { # x, y in 90*90
-    'in the center': (45, 45),
-    'on the left': (15, 45),
-    'on the right': (75, 45),
-    'on the top': (45, 15),
-    'on the bottom': (45, 75),
-    'on the top-left': (15, 15),
-    'on the top-right': (75, 15),
-    'on the bottom-left': (15, 75),
-    'on the bottom-right': (75, 75)
-}
-
-valid_offsets = { # x, y in 90*90
-    'no offset': (0, 0),
-    'slightly to the left': (-10, 0),
-    'slightly to the right': (10, 0),
-    'slightly to the upper': (0, -10),
-    'slightly to the lower': (0, 10),
-    'slightly to the upper-left': (-10, -10),
-    'slightly to the upper-right': (10, -10),
-    'slightly to the lower-left': (-10, 10),
-    'slightly to the lower-right': (10, 10)
-}
-
-valid_areas = { # w, h in 90*90
-    "a small square area": (50, 50),
-    "a small vertical area": (40, 60),
-    "a small horizontal area": (60, 40),
-    "a medium-sized square area": (60, 60),
-    "a medium-sized vertical area": (50, 80),
-    "a medium-sized horizontal area": (80, 50),
-    "a large square area": (70, 70),
-    "a large vertical area": (60, 90),
-    "a large horizontal area": (90, 60)
-}
-
-# Function to parse a character's position description
-def parse_character_position(character_position):
-    # Define the regex patterns
-    location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
-    offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
-    area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())
-
-    # Extract the location
-    location_match = re.search(location_pattern, character_position, re.IGNORECASE)
-    location = location_match.group(0) if location_match else 'in the center'
-
-    # Extract the offset
-    offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
-    offset = offset_match.group(0) if offset_match else 'no offset'
-
-    # Extract the area
-    area_match = re.search(area_pattern, character_position, re.IGNORECASE)
-    area = area_match.group(0) if area_match else 'a medium-sized square area'
-
-    return {
-        'location': location,
-        'offset': offset,
-        'area': area
-    }
-
-# Function to build the attention mask
-def create_attention_mask(image_width, image_height, location, offset, area):
-    # Positions are defined on a 90x90 base grid and then scaled to the generated image size
-    base_size = 90
-
-    # Look up the location coordinates
-    loc_x, loc_y = valid_locations.get(location, (45, 45))
-    # Look up the offset
-    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
-    # Look up the area size
-    area_width, area_height = valid_areas.get(area, (60, 60))
-
-    # Compute the final position
-    final_x = loc_x + offset_x
-    final_y = loc_y + offset_y
-
-    # Map coordinates and sizes to the actual image dimensions
-    scale_x = image_width / base_size
-    scale_y = image_height / base_size
-
-    center_x = final_x * scale_x
-    center_y = final_y * scale_y
-    width = area_width * scale_x
-    height = area_height * scale_y
-
-    # Compute the top-left and bottom-right corners
-    x_start = int(max(center_x - width / 2, 0))
-    y_start = int(max(center_y - height / 2, 0))
-    x_end = int(min(center_x + width / 2, image_width))
-    y_end = int(min(center_y + height / 2, image_height))
-
-    # Create the mask
-    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device="cuda")
-    mask[y_start:y_end, x_start:x_end] = 1.0
-
-    # Flatten to one dimension
-    mask_flat = mask.view(-1)  # shape: (image_height * image_width,)
-    return mask_flat
-
-# Custom attention processor
-
-class CustomCrossAttentionProcessor(AttnProcessor2_0):
-    def __init__(self, masks, adapter_names):
-        super().__init__()
-        self.masks = masks  # list of per-character masks (shape: [key_length])
-        self.adapter_names = adapter_names  # list of per-character LoRA adapter names
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        **kwargs,
-    ):
-        """
-        Custom attention processor that applies per-character masks during attention.
-
-        Args:
-            attn: the attention module instance.
-            hidden_states: input hidden states (query).
-            encoder_hidden_states: encoder hidden states (key/value).
-            attention_mask: attention mask.
-            temb: time embedding (possibly unused).
-            **kwargs: additional arguments.
-
-        Returns:
-            The processed hidden states.
-        """
-        # Get the current adapter_name
-        adapter_name = getattr(attn, 'adapter_name', None)
-        if adapter_name is None or adapter_name not in self.adapter_names:
-            # No adapter_name, or not one of ours: fall back to the parent __call__
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)
-
-        # Find the index for this adapter_name
-        idx = self.adapter_names.index(adapter_name)
-        mask = self.masks[idx]  # the matching mask (shape: [key_length])
-
-        # Below is the AttnProcessor2_0 implementation, with the custom mask logic inserted where needed
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        else:
-            batch_size, sequence_length, _ = hidden_states.shape
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # If encoder_hidden_states is provided, read its shape
-            encoder_batch_size, key_length, _ = encoder_hidden_states.shape
-
-        if attention_mask is not None:
-            # Prepare the attention_mask if needed
-            attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
-            # attention_mask should have shape (batch_size, attn.heads, query_length, key_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-        else:
-            # Without an attention_mask, create an all-zero mask
-            attention_mask = torch.zeros(
-                batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
-            )
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Compute the raw attention scores
-        # The mask has to be applied before the attention scores are computed;
-        # since PyTorch's scaled_dot_product_attention takes an attention mask argument, adapt our mask to that form
-
-        # Build the custom attention_mask
-        # mask has shape [key_length]; reshape it to (batch_size, attn.heads, 1, key_length)
-        custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
-        # Valid positions become 0, masked positions become -1e9 (use -65504 for float16)
-        mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
-        custom_attention_mask = (1.0 - custom_attention_mask) * mask_value  # 0 where valid, -1e9 where masked
-
-        # Add the custom mask to the attention_mask
-        attention_mask = attention_mask + custom_attention_mask
-
-        # Compute attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-# Function to replace the attention processors
-def replace_attention_processors(pipe, masks, adapter_names):
-    custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
-    for name, module in pipe.transformer.named_modules():
-        if hasattr(module, 'attn'):
-            module.attn.adapter_name = getattr(module, 'adapter_name', None)
-            module.attn.processor = custom_processor
-        if hasattr(module, 'cross_attn'):
-            module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
-            module.cross_attn.processor = custom_processor
-
 # Function to generate the image
-
-
+@spaces.GPU
+@torch.inference_mode()
+def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
     pipe.to(device)
     generator = torch.Generator(device=device).manual_seed(seed)
-
     with calculateDuration("Generating image"):
         # Generate image
         generated_image = pipe(
-
-            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt=prompt,
             num_inference_steps=steps,
             guidance_scale=cfg_scale,
             width=width,
@@ -315,111 +72,67 @@ def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, s
     progress(99, "Generate success!")
     return generated_image
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
+    print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
+    connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
+
+    s3 = boto3.client(
+        's3',
+        endpoint_url=connectionUrl,
+        region_name='auto',
+        aws_access_key_id=access_key,
+        aws_secret_access_key=secret_key
+    )
+
+    current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
+    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
+    buffer = BytesIO()
+    image.save(buffer, "PNG")
+    buffer.seek(0)
+    s3.upload_fileobj(buffer, bucket_name, image_file)
+    print("upload finish", image_file)
+
+    return image_file
 
-
-    with calculateDuration("Loading LoRA weights"):
-        pipe.unload_lora_weights()
-        adapter_names = []
-        for lora_info in lora_strings:
-            lora_repo = lora_info.get("repo")
-            weights = lora_info.get("weights")
-            adapter_name = lora_info.get("adapter_name")
-            if lora_repo and weights and adapter_name:
-                # Load the weights via pipe.load_lora_weights()
-                pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
-                adapter_names.append(adapter_name)
-                # Store the adapter_name as an attribute on the model
-                setattr(pipe.transformer, 'adapter_name', adapter_name)
+def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
-
-                raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
-        adapter_weights = [lora_scale] * len(adapter_names)
-        # Call pipeline.set_adapters to register the adapters and their weights
-        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+    # Load LoRA weights
+    if lora_strings_json:
+        try:
+            lora_strings_json = json.loads(lora_strings_json)
+        except:
+            lora_strings_json = None
+    if lora_strings_json:
+        with calculateDuration("Loading LoRA weights"):
+            pipe.unload_lora_weights()
+            adapter_names = []
+            for lora_info in lora_strings_json:
+                lora_repo = lora_info.get("repo")
+                weights = lora_info.get("weights")
+                adapter_name = lora_info.get("adapter_name")
+                if lora_repo and weights and adapter_name:
+                    # Load the LoRA weights
+                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    adapter_names.append(adapter_name)
+            adapter_weights = [lora_scale] * len(adapter_names)
+            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
-    #
-    if
-
-
     # Set random seed for reproducibility
     if randomize_seed:
         with calculateDuration("Set random seed"):
             seed = random.randint(0, MAX_SEED)
 
-    with calculateDuration("Encoding prompts"):
-        # Encode the background prompt
-        # Use tokenizer_2 and text_encoder_2
-        bg_text_input_2 = pipe.tokenizer_2(prompt_bg, return_tensors="pt").to(device)
-        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input_2.input_ids.to(device))[0]
-
-        # Use tokenizer and text_encoder
-        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
-        bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
-
-        # Encode the character prompts
-        character_prompt_embeds = []
-        character_pooled_embeds = []
-        for prompt in character_prompts:
-            # Use tokenizer_2 and text_encoder_2
-            char_text_input_2 = pipe.tokenizer_2(prompt, return_tensors="pt").to(device)
-            char_prompt_embeds = pipe.text_encoder_2(char_text_input_2.input_ids.to(device))[0]
-            # Use tokenizer and text_encoder
-            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
-            char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
-
-            character_prompt_embeds.append(char_prompt_embeds)
-            character_pooled_embeds.append(char_pooled_embeds)
-
-        # Encode the interaction-details prompt
-        details_text_input_2 = pipe.tokenizer_2(prompt_details, return_tensors="pt").to(device)
-        details_prompt_embeds = pipe.text_encoder_2(details_text_input_2.input_ids.to(device))[0]
-
-        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
-        details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
-
-        # Merge the background and interaction-details embeddings
-        prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
-        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=-1)
-
-        # Parse the character positions
-        character_infos = []
-        for position_str in character_positions:
-            info = parse_character_position(position_str)
-            character_infos.append(info)
-
-        # Build the character masks
-        masks = []
-        for info in character_infos:
-            mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
-            masks.append(mask)
-
-        # Replace the attention processors
-        replace_attention_processors(pipe, masks, adapter_names)
-
     # Generate image
-    final_image =
+    final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
 
-
-
+    if final_image:
+        if upload_to_r2:
+            with calculateDuration("Upload image"):
+                url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
+            result = {"status": "success", "message": "upload image success", "url": url}
+        else:
+            result = {"status": "success", "message": "Image generated but not uploaded"}
 
     progress(100, "Completed!")
 
@@ -439,11 +152,9 @@ with gr.Blocks(css=css) as demo:
 
         with gr.Column():
 
-
-            character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
-            character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
+            prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
             lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
-
+
             run_button = gr.Button("Run", scale=0)
 
         with gr.Accordion("Advanced Settings", open=False):
@@ -474,11 +185,8 @@ with gr.Blocks(css=css) as demo:
     json_text = gr.Text(label="Result JSON")
 
     inputs = [
-
-        character_prompts,
-        character_positions,
+        prompt,
         lora_strings_json,
-        prompt_details,
         cfg_scale,
         steps,
         randomize_seed,
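For reference, a minimal sketch of driving the updated app from Python with gradio_client; the Space id, the api_name, and the LoRA repo/weight names below are placeholders, and the positional arguments follow the new run_lora signature (prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket).

# Minimal sketch, not part of the commit: example client call with placeholder values.
import json
from gradio_client import Client

lora_strings = [
    {"repo": "your-user/your-lora", "weights": "your_lora.safetensors", "adapter_name": "adapter_name1"},
]

client = Client("your-user/your-space")  # placeholder Space id
result = client.predict(
    "a portrait photo, soft light",  # prompt
    json.dumps(lora_strings),        # lora_strings_json
    3.5,                             # cfg_scale
    28,                              # steps
    True,                            # randomize_seed
    0,                               # seed
    1024,                            # width
    1024,                            # height
    0.9,                             # lora_scale
    False,                           # upload_to_r2 (skip the Cloudflare R2 upload)
    "", "", "", "",                  # account_id, access_key, secret_key, bucket
    api_name="/run_lora",            # placeholder endpoint name
)
print(result)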