Spaces: Running on Zero
bugfix
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 import gradio as gr
 import numpy as np
 import random
+import spaces
 import torch
 import json
 import logging
@@ -10,12 +11,17 @@ from huggingface_hub import login
 import time
 from datetime import datetime
 from io import BytesIO
-from diffusers.models.attention_processor import AttentionProcessor
+# from diffusers.models.attention_processor import AttentionProcessor
+from diffusers.models.attention_processor import AttnProcessor2_0
+import torch.nn.functional as F
+
 import re
 import json
 # Log in to the Hugging Face Hub
 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
+import diffusers
+print(diffusers.__version__)
 
 # Initialization
 dtype = torch.float16  # adjust the data type as needed
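The import swap above is the heart of the bugfix: AttnProcessor2_0 routes attention through PyTorch's F.scaled_dot_product_attention, which accepts an additive attn_mask argument. A minimal, self-contained sketch of that primitive (toy shapes, unrelated to the app's real tensors):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 8, 77, 64)  # (batch, heads, key_len, head_dim)
v = torch.randn(1, 8, 77, 64)

# Additive mask: 0 keeps a key position, a large negative value suppresses it.
attn_mask = torch.zeros(1, 8, 1, 77)
attn_mask[..., 40:] = -1e9     # suppress key positions 40..76

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)               # torch.Size([1, 8, 16, 64])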
@@ -145,79 +151,160 @@ def create_attention_mask(image_width, image_height, location, offset, area):
     return mask_flat
 
 # Custom attention processor
-class CustomCrossAttentionProcessor(AttentionProcessor):
-    def __init__(self, masks, embeddings, adapter_names):
+class CustomCrossAttentionProcessor(AttnProcessor2_0):
+    def __init__(self, masks, adapter_names):
         super().__init__()
-        self.masks = masks  # list with each character's mask
-        self.embeddings = embeddings  # list with each character's embedding
+        self.masks = masks  # list with each character's mask (shape: [key_length])
         self.adapter_names = adapter_names  # list with each character's LoRA adapter name
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
+    def __call__(
+        self,
+        attn,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        **kwargs,
+    ):
+        """
+        Custom attention processor that applies per-character masks inside the attention computation.
+
+        Args:
+            attn: the attention module instance.
+            hidden_states: input hidden states (query).
+            encoder_hidden_states: encoder hidden states (key/value).
+            attention_mask: attention mask.
+            temb: time embedding (possibly unused).
+            **kwargs: additional arguments.
+
+        Returns:
+            The processed hidden states.
+        """
         # Get the current adapter_name
         adapter_name = getattr(attn, 'adapter_name', None)
         if adapter_name is None or adapter_name not in self.adapter_names:
-            # No adapter_name
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+            # No adapter_name, or not in our list: fall back to the parent class's __call__
+            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)
 
         # Look up the index for this adapter_name
         idx = self.adapter_names.index(adapter_name)
-        mask = self.masks[idx]
+        mask = self.masks[idx]  # the corresponding mask (shape: [key_length])
+
+        # What follows is the AttnProcessor2_0 implementation, with the custom mask logic inserted where needed
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        else:
+            batch_size, sequence_length, _ = hidden_states.shape
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # If encoder_hidden_states is provided, record its shape
+            encoder_batch_size, key_length, _ = encoder_hidden_states.shape
+
+        if attention_mask is not None:
+            # Prepare the attention_mask if necessary
+            attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
+            # attention_mask should have shape (batch_size, attn.heads, query_length, key_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+        else:
+            # If there is no attention_mask, create an all-zeros one
+            attention_mask = torch.zeros(
+                batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
+            )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
         query = attn.to_q(hidden_states)
+
+        if attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        # Apply the mask to attention_scores
-        attention_scores += mask_expanded * 1e6  # boost attention at the corresponding positions
-        # Output projection
-        hidden_states = attn.to_out(context)
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Compute the attention scores
+        # We need to apply our mask before attention is computed, and since PyTorch's
+        # scaled_dot_product_attention accepts an attention_mask argument, we adapt our mask to it
+
+        # Build the custom attention_mask
+        # mask has shape [key_length]; reshape it to (batch_size, attn.heads, 1, key_length)
+        custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
+        # Valid positions become 0, masked positions become -1e9 (-65504 for float16)
+        mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
+        custom_attention_mask = (1.0 - custom_attention_mask) * mask_value  # 0 where valid, -1e9 where masked
 
+        # Add the custom mask onto the attention_mask
+        attention_mask = attention_mask + custom_attention_mask
 
+        # Compute attention
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
         return hidden_states
 
+
 # Function to replace the attention processors
-def replace_attention_processors(pipe, masks,
-    custom_processor = CustomCrossAttentionProcessor(masks,
-    for name, module in pipe.
-        if hasattr(module, '
-            module.
+def replace_attention_processors(pipe, masks, adapter_names):
+    custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
+    for name, module in pipe.transformer.named_modules():
+        if hasattr(module, 'attn'):
+            module.attn.adapter_name = getattr(module, 'adapter_name', None)
+            module.attn.processor = custom_processor
+        if hasattr(module, 'cross_attn'):
+            module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
+            module.cross_attn.processor = custom_processor
 
 # Image generation function
-def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
-    generator = torch.Generator(device="cuda").manual_seed(seed)
+def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
+    pipe.to(device)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
     with calculateDuration("Generating image"):
         # Generate image
         generated_image = pipe(
-            prompt_embeds=prompt_embeddings,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
             num_inference_steps=steps,
             guidance_scale=cfg_scale,
             width=width,
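The substantive change in this hunk is the masking scheme: the old code added mask_expanded * 1e6 to raw attention scores to boost a character's positions, while the new code converts the 0/1 mask into an additive bias of 0 for kept positions and -65504 (the most negative float16 value) for suppressed ones, then lets scaled_dot_product_attention apply it. A small self-contained check of that transform (toy sizes, assuming mask is a 0/1 vector over key positions, as the diff's comments state):

import torch
import torch.nn.functional as F

# Toy 0/1 mask over 6 key positions: 1 = this character's tokens, 0 = everything else.
mask = torch.tensor([1., 1., 0., 0., 1., 0.])

# The diff's transform: kept positions -> 0, masked positions -> a large negative bias.
# torch.finfo(torch.float16).min is -65504.0, the constant hard-coded above.
mask_value = torch.finfo(torch.float16).min
bias = (1.0 - mask.view(1, 1, 1, -1)) * mask_value

q = torch.randn(1, 1, 2, 8)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 1, 6, 8)
v = torch.randn(1, 1, 6, 8)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
# The same thing by hand: softmax over (scores + bias) drives masked positions to ~0.
scores = q @ k.transpose(-1, -2) / (8 ** 0.5) + bias
weights = scores.softmax(dim=-1)
print(weights[0, 0, 0])  # near-zero probability at positions 2, 3, and 5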
@@ -229,7 +316,8 @@ def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, wi
     return generated_image
 
 # Main function
+@spaces.GPU
+@torch.inference_mode()
 def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
     # Parse the character prompts, positions, and LoRA strings
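@spaces.GPU is the ZeroGPU hook: on "Running on Zero" hardware, a GPU is attached only while the decorated function executes, and @torch.inference_mode() disables autograd tracking for the whole call. A minimal sketch of the pattern (the spaces package is only available inside a Hugging Face Space, and the function body is a placeholder):

import spaces
import torch

@spaces.GPU              # request a GPU for the duration of this call (ZeroGPU)
@torch.inference_mode()  # no gradient bookkeeping; inference only
def infer(prompt: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ... move the pipeline to `device` and run generation here ...
    return prompt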
@@ -260,7 +348,8 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
             pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
             adapter_names.append(adapter_name)
             # Set adapter_name as an attribute on the model
-            setattr(pipe.
+            setattr(pipe.transformer, 'adapter_name', adapter_name)
+
         else:
             raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
     adapter_weights = [lora_scale] * len(adapter_names)
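For context on the surrounding lines (the hunk ends by computing adapter_weights = [lora_scale] * len(adapter_names)): diffusers can hold several LoRA adapters at once and activate them together with per-adapter weights. A hedged sketch of that pattern, where the repo names and weight files are hypothetical placeholders and pipe is assumed to be an already-constructed pipeline with LoRA support:

# `pipe` is assumed to be a diffusers pipeline that supports LoRA loading.
pipe.load_lora_weights("user/char-a-lora", weight_name="char_a.safetensors", adapter_name="char_a")
pipe.load_lora_weights("user/char-b-lora", weight_name="char_b.safetensors", adapter_name="char_b")

adapter_names = ["char_a", "char_b"]
adapter_weights = [0.8, 0.8]  # one scale per adapter, as the diff does with lora_scale
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)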
@@ -279,22 +368,28 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
     # Encode the prompts
     with calculateDuration("Encoding prompts"):
         # Encode the background prompt
-        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(
+        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
+        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input.input_ids.to(device))[0]
+        bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
+
         # Encode the character prompts
+        character_prompt_embeds = []
+        character_pooled_embeds = []
         for prompt in character_prompts:
-            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(
+            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
+            char_prompt_embeds = pipe.text_encoder_2(char_text_input.input_ids.to(device))[0]
+            char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
+            character_prompt_embeds.append(char_prompt_embeds)
+            character_pooled_embeds.append(char_pooled_embeds)
+
         # Encode the interaction-details prompt
-        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(
+        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
+        details_prompt_embeds = pipe.text_encoder_2(details_text_input.input_ids.to(device))[0]
+        details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
+
         # Merge the background and interaction-details embeddings
+        prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
+        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)
 
     # Parse the character positions
     character_infos = []
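One shape detail worth pinning down in this hunk: sequence embeddings are (batch, tokens, dim), so torch.cat(..., dim=1) appends tokens, but pooled embeddings are (batch, dim), so the same call doubles the feature width. A self-contained toy check (the dimensions are illustrative, not the app's actual encoder sizes):

import torch

# Sequence embeddings: (batch, tokens, dim); dim=1 concatenation appends tokens.
bg_prompt_embeds = torch.randn(1, 12, 4096)
details_prompt_embeds = torch.randn(1, 9, 4096)
print(torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1).shape)  # [1, 21, 4096]

# Pooled embeddings: (batch, dim); dim=1 concatenation doubles the width,
# so a pipeline expecting a fixed pooled size would need a different merge
# (for example, averaging) rather than concatenation.
bg_pooled_embeds = torch.randn(1, 768)
details_pooled_embeds = torch.randn(1, 768)
print(torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1).shape)  # [1, 1536]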
@@ -309,10 +404,10 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
         masks.append(mask)
 
     # Replace the attention processors
-    replace_attention_processors(pipe, masks,
+    replace_attention_processors(pipe, masks, adapter_names)
 
     # Generate image
-    final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)
+    final_image = generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress)
 
     # You can add the image-upload code here
    result = {"status": "success", "message": "Image generated"}
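create_attention_mask itself is outside this diff; per its signature and the "return mask_flat" context line, it flattens a per-character spatial mask. A hypothetical stand-in sketch, assuming a simple left/right region split (make_region_mask is an invented name, not the app's function):

import torch

def make_region_mask(width: int, height: int, region: str = "left") -> torch.Tensor:
    """Hypothetical stand-in for create_attention_mask: 1 inside the
    character's region, 0 elsewhere, flattened like mask_flat."""
    mask = torch.zeros(height, width)
    if region == "left":
        mask[:, : width // 2] = 1.0
    elif region == "right":
        mask[:, width // 2 :] = 1.0
    return mask.flatten()  # shape: [height * width]

mask_flat = make_region_mask(64, 64, "left")
print(mask_flat.shape, mask_flat.sum())  # torch.Size([4096]) tensor(2048.)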
@@ -334,7 +429,7 @@ with gr.Blocks(css=css) as demo:
     with gr.Row():
 
         with gr.Column():
+
             prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
             character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
             character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
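The three JSON textboxes expect parallel lists, one entry per character; the LoRA list format is pinned down by the ValueError earlier in the diff ('repo', 'weights', and 'adapter_name' keys). Illustrative inputs (all names are hypothetical):

import json

character_prompts = json.dumps(["a knight in silver armor", "a mage in a blue robe"])
character_positions = json.dumps(["left", "right"])
lora_strings = json.dumps([
    {"repo": "user/knight-lora", "weights": "knight.safetensors", "adapter_name": "knight"},
    {"repo": "user/mage-lora", "weights": "mage.safetensors", "adapter_name": "mage"},
])
print(character_prompts, character_positions, lora_strings, sep="\n")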