Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import gradio as gr | |
import numpy as np | |
import random | |
import spaces | |
import torch | |
import json | |
import logging | |
from diffusers import DiffusionPipeline | |
from huggingface_hub import login | |
import time | |
from datetime import datetime | |
from io import BytesIO | |
# from diffusers.models.attention_processor import AttentionProcessor | |
from diffusers.models.attention_processor import AttnProcessor2_0 | |
import torch.nn.functional as F | |
import re | |
import json | |
# Log in to the Hugging Face Hub using the token from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered in a headless deployment; skip it and rely on cached credentials.
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")
import diffusers
print(diffusers.__version__)
# Initialisation
dtype = torch.float16  # adjust the data type as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your own model
# Load the pipeline once at import time and move it to the selected device.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
# Upper bound (inclusive) for seeds drawn or accepted by the UI.
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that times its ``with`` body and prints the duration.

    After the block exits, ``start_time``, ``end_time`` and ``elapsed_time``
    (all in seconds, taken from ``time.time()``) are available on the instance.
    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Optional label included in the printed message.
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        finished_at = time.time()
        self.end_time = finished_at
        self.elapsed_time = finished_at - self.start_time
        if not self.activity_name:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
            return
        print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
# Lookup tables that map position phrases to coordinates on a 90x90 base grid.

# Anchor point (x, y) of each named location.
valid_locations = {
    "in the center": (45, 45),
    "on the left": (15, 45),
    "on the right": (75, 45),
    "on the top": (45, 15),
    "on the bottom": (45, 75),
    "on the top-left": (15, 15),
    "on the top-right": (75, 15),
    "on the bottom-left": (15, 75),
    "on the bottom-right": (75, 75),
}

# Displacement (dx, dy) added to the anchor point; negative y moves upward.
valid_offsets = {
    "no offset": (0, 0),
    "slightly to the left": (-10, 0),
    "slightly to the right": (10, 0),
    "slightly to the upper": (0, -10),
    "slightly to the lower": (0, 10),
    "slightly to the upper-left": (-10, -10),
    "slightly to the upper-right": (10, -10),
    "slightly to the lower-left": (-10, 10),
    "slightly to the lower-right": (10, 10),
}

# Region size (width, height) on the same 90x90 grid.
valid_areas = {
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60),
}
# Parse a free-text character position description.
def parse_character_position(character_position):
    """Extract the location, offset and area phrases from a description.

    The text is searched (case-insensitively) for one phrase from each of
    ``valid_locations``, ``valid_offsets`` and ``valid_areas``.

    Args:
        character_position: Free-text description such as
            ``"on the top-left, slightly to the right, a small square area"``.
            ``None`` or an empty string yields the defaults.

    Returns:
        A dict with keys ``'location'``, ``'offset'`` and ``'area'``. Each
        value is the matching table key as spelled in its table, or the
        default (``'in the center'``, ``'no offset'``,
        ``'a medium-sized square area'``) when no phrase is found.
    """
    text = character_position or ""

    def _find_phrase(choices, fallback):
        # Regex alternation is ordered: at a given position the first
        # alternative that matches wins. Try longer phrases first so that
        # 'on the top' cannot shadow 'on the top-left'. sorted() is stable,
        # so equal-length phrases keep their table order.
        ordered = sorted(choices, key=len, reverse=True)
        pattern = '|'.join(re.escape(phrase) for phrase in ordered)
        match = re.search(pattern, text, re.IGNORECASE)
        if match is None:
            return fallback
        # With IGNORECASE the matched text keeps the caller's casing; map it
        # back to the table's own key so later dict lookups succeed.
        canonical = {phrase.lower(): phrase for phrase in choices}
        return canonical.get(match.group(0).lower(), fallback)

    return {
        'location': _find_phrase(valid_locations, 'in the center'),
        'offset': _find_phrase(valid_offsets, 'no offset'),
        'area': _find_phrase(valid_areas, 'a medium-sized square area'),
    }
# Build the flattened region mask for one character.
def create_attention_mask(image_width, image_height, location, offset, area, device=None):
    """Create a flattened binary mask marking a rectangular image region.

    The rectangle is described on a 90x90 base grid via the lookup tables
    ``valid_locations`` (centre), ``valid_offsets`` (shift) and
    ``valid_areas`` (size), then scaled to the requested image size and
    clipped to the image bounds.

    Args:
        image_width: Mask width in pixels.
        image_height: Mask height in pixels.
        location: Key of ``valid_locations``; unknown keys fall back to (45, 45).
        offset: Key of ``valid_offsets``; unknown keys fall back to (0, 0).
        area: Key of ``valid_areas``; unknown keys fall back to (60, 60).
        device: Torch device for the mask. Defaults to ``"cuda"`` when CUDA
            is available and ``"cpu"`` otherwise.

    Returns:
        A float32 tensor of shape ``(image_height * image_width,)`` holding
        1.0 inside the rectangle and 0.0 elsewhere.
    """
    # NOTE(review): the mask has one entry per pixel, while the attention
    # processor in this file adds it along the key axis — confirm that the
    # key sequence length really equals image_height * image_width.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.zeros needs integer sizes; UI sliders may deliver floats.
    image_width = int(image_width)
    image_height = int(image_height)

    base_size = 90
    loc_x, loc_y = valid_locations.get(location, (45, 45))
    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
    area_width, area_height = valid_areas.get(area, (60, 60))

    # Scale factors from the base grid to the actual image size.
    scale_x = image_width / base_size
    scale_y = image_height / base_size
    center_x = (loc_x + offset_x) * scale_x
    center_y = (loc_y + offset_y) * scale_y
    width = area_width * scale_x
    height = area_height * scale_y

    # Top-left and bottom-right corners, clipped to the image.
    x_start = int(max(center_x - width / 2, 0))
    y_start = int(max(center_y - height / 2, 0))
    x_end = int(min(center_x + width / 2, image_width))
    y_end = int(min(center_y + height / 2, image_height))

    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device=device)
    mask[y_start:y_end, x_start:x_end] = 1.0
    # Flatten row-major to shape (image_height * image_width,).
    return mask.view(-1)
# Custom attention processor that restricts keys with per-character masks.
class CustomCrossAttentionProcessor(AttnProcessor2_0):
    """Attention processor that applies a per-adapter key mask.

    When the attention module carries an ``adapter_name`` attribute that is
    listed in ``adapter_names``, the mask at the same index in ``masks`` is
    turned into an additive bias on the key axis. Otherwise the call is
    delegated unchanged to ``AttnProcessor2_0``.
    """
    # NOTE(review): the pipeline loaded in this file is a FLUX model; confirm
    # that its attention blocks are compatible with AttnProcessor2_0-style
    # processing (call arguments and return value).

    def __init__(self, masks, adapter_names):
        super().__init__()
        self.masks = masks  # one flat mask per character, 1.0 = key allowed
        self.adapter_names = adapter_names  # LoRA adapter name per character

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        **kwargs,
    ):
        """Run attention, adding the character mask when an adapter matches.

        Args:
            attn: The attention module instance.
            hidden_states: Input hidden states (source of the queries).
            encoder_hidden_states: Source of keys/values; when ``None`` the
                hidden states themselves are used.
            attention_mask: Optional attention mask.
            temb: Embedding forwarded to ``attn.spatial_norm`` when present.
            **kwargs: Forwarded to the parent processor on the fallback path.

        Returns:
            The processed hidden states.

        Raises:
            ValueError: If the selected mask length differs from the key length.
        """
        adapter_name = getattr(attn, 'adapter_name', None)
        if adapter_name is None or adapter_name not in self.adapter_names:
            # No adapter, or one we hold no mask for: use the stock behaviour.
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)

        mask = self.masks[self.adapter_names.index(adapter_name)]

        # The remainder follows the structure of AttnProcessor2_0, with the
        # custom mask merged into the attention bias.
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, _, _ = hidden_states.shape

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # Derive the key length after resolving the key source so that it is
        # defined on the self-attention path as well.
        key_length = encoder_hidden_states.shape[1]

        if mask.numel() != key_length:
            raise ValueError(
                f"Mask for adapter '{adapter_name}' has {mask.numel()} entries "
                f"but the key sequence has length {key_length}."
            )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
            # Reshape to (batch_size, heads, query_length, key_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        else:
            # No mask supplied: start from an all-zero (no-op) additive bias.
            attention_mask = torch.zeros(
                batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Convert the 0/1 mask into an additive bias broadcast over batch,
        # heads and queries: 0 where allowed, a large negative where blocked.
        custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
        # -65504 is the most negative finite float16 value.
        mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
        custom_attention_mask = (1.0 - custom_attention_mask) * mask_value
        attention_mask = attention_mask + custom_attention_mask

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
# Install the custom attention processor on the pipeline's transformer.
def replace_attention_processors(pipe, masks, adapter_names):
    """Attach one shared ``CustomCrossAttentionProcessor`` to attention layers.

    Every sub-module of ``pipe.transformer`` that exposes an ``attn`` and/or
    ``cross_attn`` attribute gets that attention object's ``processor``
    replaced. The attention object's ``adapter_name`` is copied from the
    owning sub-module (``None`` when the sub-module has no such attribute).

    Args:
        pipe: Pipeline whose ``transformer`` is modified in place.
        masks: Per-character masks handed to the processor.
        adapter_names: Adapter names aligned with ``masks``.
    """
    shared_processor = CustomCrossAttentionProcessor(masks, adapter_names)
    for _, submodule in pipe.transformer.named_modules():
        for attr in ('attn', 'cross_attn'):
            if not hasattr(submodule, attr):
                continue
            attention = getattr(submodule, attr)
            attention.adapter_name = getattr(submodule, 'adapter_name', None)
            attention.processor = shared_processor
# Run the diffusion pipeline on pre-computed prompt embeddings.
def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
    """Generate a single image from prompt embeddings.

    Args:
        prompt_embeds: Prompt embeddings passed to the pipeline.
        pooled_prompt_embeds: Pooled prompt embeddings passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed for the random generator.
        cfg_scale: Guidance scale.
        width: Output image width.
        height: Output image height.
        progress: Progress callback, called as ``progress(fraction, desc)``.

    Returns:
        The first image produced by the pipeline.
    """
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)
    with calculateDuration("Generating image"):
        generated_image = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        # Gradio expects a completion fraction in [0, 1], not a percentage.
        progress(0.99, "Generate success!")
    return generated_image
# Main entry point wired to the Run button.
def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Load the requested LoRA adapters and generate one image.

    Args:
        prompt_bg: Background/scene prompt.
        character_prompts_json: JSON list of per-character prompts.
        character_positions_json: JSON list of per-character position texts.
        lora_strings_json: JSON list of objects with ``repo``, ``weights``
            and ``adapter_name`` keys, one per character.
        prompt_details: Prompt describing the interaction between characters.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        seed: Seed used when ``randomize_seed`` is false.
        width: Output image width.
        height: Output image height.
        lora_scale: Weight applied to every loaded adapter.
        upload_to_r2, account_id, access_key, secret_key, bucket: Accepted
            for the UI wiring; not used by this function yet.
        progress: Gradio progress tracker.

    Returns:
        A tuple ``(image, seed, result_json)``.

    Raises:
        ValueError: On malformed JSON, non-list inputs, mismatched list
            lengths, or a LoRA entry that lacks a required key.
    """
    # Parse the character prompts, positions and LoRA descriptors.
    try:
        character_prompts = json.loads(character_prompts_json)
        character_positions = json.loads(character_positions_json)
        lora_strings = json.loads(lora_strings_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}") from e

    for label, parsed in (
        ("character prompts", character_prompts),
        ("character positions", character_positions),
        ("LoRA strings", lora_strings),
    ):
        if not isinstance(parsed, list):
            raise ValueError(f"The {label} input must be a JSON list.")

    # The three lists must describe the same set of characters.
    if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
        raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")
    num_characters = len(character_prompts)

    # Load LoRA weights
    invalid_lora_message = "Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys."
    with calculateDuration("Loading LoRA weights"):
        pipe.unload_lora_weights()
        adapter_names = []
        for lora_info in lora_strings:
            if not isinstance(lora_info, dict):
                raise ValueError(invalid_lora_message)
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            if not (lora_repo and weights and adapter_name):
                raise ValueError(invalid_lora_message)
            pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            # Record the adapter name on the transformer; each iteration
            # overwrites the previous value, so the last adapter wins.
            setattr(pipe.transformer, 'adapter_name', adapter_name)
        adapter_weights = [lora_scale] * len(adapter_names)
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    if len(adapter_names) != num_characters:
        raise ValueError("The number of LoRA adapters must match the number of characters.")

    # Set random seed for reproducibility
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)

    def _encode(text):
        # NOTE(review): ids from pipe.tokenizer feed both text_encoder_2 and
        # text_encoder; confirm each encoder is paired with its own tokenizer
        # (pipe.tokenizer_2 presumably belongs to text_encoder_2).
        tokens = pipe.tokenizer(text, return_tensors="pt").to(device)
        input_ids = tokens.input_ids.to(device)
        sequence_embeds = pipe.text_encoder_2(input_ids)[0]
        pooled_embeds = pipe.text_encoder(input_ids).pooler_output
        return sequence_embeds, pooled_embeds

    # Encode the prompts. Inference only, so skip autograd bookkeeping.
    with calculateDuration("Encoding prompts"), torch.no_grad():
        bg_prompt_embeds, bg_pooled_embeds = _encode(prompt_bg)

        # NOTE(review): the per-character embeddings are collected here but
        # are not used further down.
        character_prompt_embeds = []
        character_pooled_embeds = []
        for prompt in character_prompts:
            char_prompt_embeds, char_pooled_embeds = _encode(prompt)
            character_prompt_embeds.append(char_prompt_embeds)
            character_pooled_embeds.append(char_pooled_embeds)

        details_prompt_embeds, details_pooled_embeds = _encode(prompt_details)

        # Merge background and interaction-detail embeddings.
        prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
        # NOTE(review): concatenating pooled embeddings along dim=1 widens the
        # feature axis; confirm the transformer accepts that width.
        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)

    # Parse each character's position and build its region mask.
    character_infos = [parse_character_position(position_str) for position_str in character_positions]
    masks = [
        create_attention_mask(width, height, info['location'], info['offset'], info['area'])
        for info in character_infos
    ]

    # Swap in the mask-aware attention processor.
    replace_attention_processors(pipe, masks, adapter_names)

    # Generate image (the original passed an undefined name here).
    final_image = generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress)

    # Image upload code can be added here.
    result = {"status": "success", "message": "Image generated"}
    # Gradio expects a completion fraction in [0, 1].
    progress(1.0, "Completed!")
    return final_image, seed, json.dumps(result)
# Gradio interface
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")

    with gr.Row():
        # Left column: prompts, run button and advanced settings.
        with gr.Column():
            prompt_bg = gr.Text(
                label="Background Prompt",
                placeholder="Enter background/scene prompt",
                lines=2,
            )
            character_prompts = gr.Text(
                label="Character Prompts (JSON List)",
                placeholder='["Character 1 prompt", "Character 2 prompt"]',
                lines=5,
            )
            character_positions = gr.Text(
                label="Character Positions (JSON List)",
                placeholder='["Character 1 position", "Character 2 position"]',
                lines=5,
            )
            lora_strings_json = gr.Text(
                label="LoRA Strings (JSON List)",
                placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]',
                lines=5,
            )
            prompt_details = gr.Text(
                label="Interaction Details",
                placeholder="Enter interaction details between characters",
                lines=2,
            )
            run_button = gr.Button("Run", scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    lora_scale = gr.Slider(
                        label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width", minimum=256, maximum=1536, step=64, value=512
                    )
                    height = gr.Slider(
                        label="Height", minimum=256, maximum=1536, step=64, value=512
                    )
                with gr.Row():
                    cfg_scale = gr.Slider(
                        label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5
                    )
                    steps = gr.Slider(
                        label="Steps", minimum=1, maximum=50, step=1, value=28
                    )
                # Optional R2 upload settings (forwarded to run_lora).
                upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")

        # Right column: generated image and run metadata.
        with gr.Column():
            result = gr.Image(label="Result", show_label=False)
            seed_output = gr.Text(label="Seed")
            json_text = gr.Text(label="Result JSON")

    # Component order must match the positional parameters of run_lora.
    inputs = [
        prompt_bg,
        character_prompts,
        character_positions,
        lora_strings_json,
        prompt_details,
        cfg_scale,
        steps,
        randomize_seed,
        seed,
        width,
        height,
        lora_scale,
        upload_to_r2,
        account_id,
        access_key,
        secret_key,
        bucket,
    ]
    outputs = [result, seed_output, json_text]

    run_button.click(fn=run_lora, inputs=inputs, outputs=outputs)

demo.queue().launch()