jiuface committed on
Commit f93e467
1 Parent(s): 53d0f2f
Files changed (1): app.py +303 -146
app.py CHANGED
@@ -2,51 +2,30 @@ import os
 import gradio as gr
 import numpy as np
 import random
- import spaces
- from diffusers import DiffusionPipeline
 import torch
 import json
 import logging
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
 from huggingface_hub import login
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
- import copy
- import random
 import time
- import boto3
- from io import BytesIO
 from datetime import datetime
- from transformers import AutoTokenizer
-
- from diffusers import UNet2DConditionModel
-
-
-
 HF_TOKEN = os.environ.get("HF_TOKEN")
-
 login(token=HF_TOKEN)
 
- # init
- dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
- base_model = "black-forest-labs/FLUX.1-dev"
-
- # unet = UNet2DConditionModel.from_pretrained(
- #     base_model,
- #     torch_dtype=torch.float16,
- #     use_safetensors=True,
- #     variant="fp16",
- #     subfolder="unet",
- # ).to("cuda")
- # tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
-
 
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
 
-
-
- MAX_SEED = 2**32-1
 
 class calculateDuration:
     def __init__(self, activity_name=""):
@@ -64,133 +43,285 @@ class calculateDuration:
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
- def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
-     print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
-     connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
-
-     s3 = boto3.client(
-         's3',
-         endpoint_url=connectionUrl,
-         region_name='auto',
-         aws_access_key_id=access_key,
-         aws_secret_access_key=secret_key
-     )
-
-     current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
-     image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
-     buffer = BytesIO()
-     image.save(buffer, "PNG")
-     buffer.seek(0)
-     s3.upload_fileobj(buffer, bucket_name, image_file)
-     print("upload finish", image_file)
-     return image_file
 
- @spaces.GPU
- def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
-     pipe.to("cuda")
 
-     text_inputs = pipe.tokenizer(prompt, return_tensors="pt").to("cuda")
-     input_ids = text_inputs.input_ids[0]
-
-     # Get the token ID for each subject
-     boy_token_id = pipe.tokenizer.convert_tokens_to_ids("boy_asia_05")
-     print(boy_token_id)
-     girl_token_id = pipe.tokenizer.convert_tokens_to_ids("girl_asia_04")
-     print(girl_token_id)
-     # Find the index positions of each subject in the input
-     boy_indices = (input_ids == boy_token_id).nonzero(as_tuple=True)[0]
-     girl_indices = (input_ids == girl_token_id).nonzero(as_tuple=True)[0]
-
-     # Prepare cross_attention_kwargs
-     def attention_control(attention_probs, adapter_name):
-         # Steer attention according to adapter_name and the token indices
-         print("attention_control", adapter_name)
-         if adapter_name == "boy_asia_05":
-             # Zero out attention to the girl's tokens
-             attention_probs[:, :, :, girl_indices] = 0
-         elif adapter_name == "girl_asia_04":
-             # Zero out attention to the boy's tokens
-             attention_probs[:, :, :, boy_indices] = 0
-         return attention_probs
 
-     joint_attention_kwargs = {"attention_control": attention_control}
-
     generator = torch.Generator(device="cuda").manual_seed(seed)
     with calculateDuration("Generating image"):
         # Generate image
-         generate_image = pipe(
-             prompt=prompt,
             num_inference_steps=steps,
             guidance_scale=cfg_scale,
             width=width,
             height=height,
             generator=generator,
-             joint_attention_kwargs=joint_attention_kwargs
         ).images[0]
-
     progress(99, "Generate success!")
-     return generate_image
-
- # Custom attention processor used inside the transformer
- class CustomAttentionProcessor(torch.nn.Module):
-     def __init__(self, attention_control, adapter_name):
-         super().__init__()
-         self.attention_control = attention_control
-         self.adapter_name = adapter_name
-
-     def forward(self, attention_probs):
-         # Invoke the custom attention-control function
-         attention_probs = self.attention_control(attention_probs, self.adapter_name)
-         return attention_probs
 
- def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
 
-
     # Load LoRA weights
-     if lora_strings:
-         with calculateDuration(f"Loading LoRA weights for {lora_strings}"):
-             pipe.unload_lora_weights()
-             lora_array = lora_strings.split(',')
-             adapter_names = []
-             for lora_string in lora_array:
-                 parts = lora_string.split(':')
-                 if len(parts) == 3:
-                     lora_repo, weights, adapter_name = parts
-                     # Load the weights via pipe.load_lora_weights()
-                     pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
-                     adapter_names.append(adapter_name)
-                 else:
-                     print(f"Invalid format for lora_string: {lora_string}")
-
-             adapter_weights = [lora_scale] * len(adapter_names)
-             # Set the adapters and their weights via pipeline.set_adapters
-             pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
-
-
     # Set random seed for reproducibility
     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
 
-     if upload_to_r2:
-         with calculateDuration("upload r2"):
-             url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
-         result = {"status": "success", "url": url}
-     else:
-         result = {"status": "success", "message": "Image generated but not uploaded"}
 
-     progress(100, "Completed!")
 
-     yield final_image, seed, json.dumps(result)
 
 css="""
 #col-container {
     margin: 0 auto;
@@ -199,12 +330,16 @@ css="""
 """
 
 with gr.Blocks(css=css) as demo:
-     gr.Markdown("Flux with lora")
     with gr.Row():
 
         with gr.Column():
-             prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
-             lora_strings = gr.Text(label="lora_strings", max_lines=1, placeholder="Enter a lora strings", visible=True)
             run_button = gr.Button("Run", scale=0)
 
         with gr.Accordion("Advanced Settings", open=False):
@@ -215,11 +350,11 @@ with gr.Blocks(css=css) as demo:
             lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
 
             with gr.Row():
-                 width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
-                 height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
 
             with gr.Row():
-                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                 steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
 
             upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
@@ -231,13 +366,35 @@ with gr.Blocks(css=css) as demo:
 
         with gr.Column():
             result = gr.Image(label="Result", show_label=False)
-             json_text = gr.Text()
-
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn = run_lora,
-         inputs = [prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket],
-         outputs=[result, seed, json_text]
     )
 
- demo.queue().launch()
 
 import gradio as gr
 import numpy as np
 import random
 import torch
 import json
 import logging
+ from diffusers import DiffusionPipeline
 from huggingface_hub import login
 import time
 from datetime import datetime
+ from io import BytesIO
+ from diffusers.models.attention_processor import AttentionProcessor
+ import re
+ import json
+ # Log in to the Hugging Face Hub
 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
 
+ # Initialization
+ dtype = torch.float16  # adjust the data type as needed
 device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "black-forest-labs/FLUX.1-dev"  # replace with your model
 
+ # Load the pipeline
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
 
+ MAX_SEED = 2**32 - 1
 
 class calculateDuration:
     def __init__(self, activity_name=""):
 
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
+ # Mappings for locations, offsets, and areas
+ valid_locations = {  # x, y in 90*90
+     'in the center': (45, 45),
+     'on the left': (15, 45),
+     'on the right': (75, 45),
+     'on the top': (45, 15),
+     'on the bottom': (45, 75),
+     'on the top-left': (15, 15),
+     'on the top-right': (75, 15),
+     'on the bottom-left': (15, 75),
+     'on the bottom-right': (75, 75)
+ }
 
+ valid_offsets = {  # x, y in 90*90
+     'no offset': (0, 0),
+     'slightly to the left': (-10, 0),
+     'slightly to the right': (10, 0),
+     'slightly to the upper': (0, -10),
+     'slightly to the lower': (0, 10),
+     'slightly to the upper-left': (-10, -10),
+     'slightly to the upper-right': (10, -10),
+     'slightly to the lower-left': (-10, 10),
+     'slightly to the lower-right': (10, 10)
+ }
 
+ valid_areas = {  # w, h in 90*90
+     "a small square area": (50, 50),
+     "a small vertical area": (40, 60),
+     "a small horizontal area": (60, 40),
+     "a medium-sized square area": (60, 60),
+     "a medium-sized vertical area": (50, 80),
+     "a medium-sized horizontal area": (80, 50),
+     "a large square area": (70, 70),
+     "a large vertical area": (60, 90),
+     "a large horizontal area": (90, 60)
+ }
 
+ # Parse a character-position description
+ def parse_character_position(character_position):
+     # Build regular-expression patterns from the valid keys
+     location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
+     offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
+     area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())
+
+     # Extract the location
+     location_match = re.search(location_pattern, character_position, re.IGNORECASE)
+     location = location_match.group(0) if location_match else 'in the center'
+
+     # Extract the offset
+     offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
+     offset = offset_match.group(0) if offset_match else 'no offset'
+
+     # Extract the area
+     area_match = re.search(area_pattern, character_position, re.IGNORECASE)
+     area = area_match.group(0) if area_match else 'a medium-sized square area'
+
+     return {
+         'location': location,
+         'offset': offset,
+         'area': area
+     }
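
A quick illustration (not part of the commit) of how `parse_character_position` resolves a free-form description; the sample strings are hypothetical. Note that `re.search` tries the alternatives in dictionary insertion order, so a string containing "slightly to the upper-left" would match the shorter key "slightly to the upper" first:

```python
print(parse_character_position("on the left, slightly to the upper, a small square area"))
# {'location': 'on the left', 'offset': 'slightly to the upper', 'area': 'a small square area'}

# Unrecognised descriptions fall back to the defaults:
print(parse_character_position("floating in mid-air"))
# {'location': 'in the center', 'offset': 'no offset', 'area': 'a medium-sized square area'}
```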
 
+ # Build a rectangular attention mask for one character
+ def create_attention_mask(image_width, image_height, location, offset, area):
+     # Positions are specified on a 90x90 base grid, so define the base size first
+     base_size = 90
+
+     # Look up the location coordinates
+     loc_x, loc_y = valid_locations.get(location, (45, 45))
+     # Look up the offset
+     offset_x, offset_y = valid_offsets.get(offset, (0, 0))
+     # Look up the area size
+     area_width, area_height = valid_areas.get(area, (60, 60))
+
+     # Compute the final position
+     final_x = loc_x + offset_x
+     final_y = loc_y + offset_y
+
+     # Map coordinates and sizes to the actual image resolution
+     scale_x = image_width / base_size
+     scale_y = image_height / base_size
+
+     center_x = final_x * scale_x
+     center_y = final_y * scale_y
+     width = area_width * scale_x
+     height = area_height * scale_y
+
+     # Top-left and bottom-right corners, clamped to the image
+     x_start = int(max(center_x - width / 2, 0))
+     y_start = int(max(center_y - height / 2, 0))
+     x_end = int(min(center_x + width / 2, image_width))
+     y_end = int(min(center_y + height / 2, image_height))
+
+     # Create the mask
+     mask = torch.zeros((image_height, image_width), dtype=torch.float32, device="cuda")
+     mask[y_start:y_end, x_start:x_end] = 1.0
+
+     # Flatten to 1D
+     mask_flat = mask.view(-1)  # shape: (image_height * image_width,)
+     return mask_flat
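
To make the 90x90-to-pixel mapping concrete, here is the rectangle the function produces for a 512x512 image (a sketch; it only runs where CUDA is available, since the mask is allocated on "cuda"):

```python
# 'on the left' -> (15, 45), 'no offset' -> (0, 0), 'a small square area' -> (50, 50)
# scale = 512 / 90 ≈ 5.69, so center ≈ (85.3, 256.0) and size ≈ 284.4 x 284.4.
# Clamped to the canvas: x in [0, 227], y in [113, 398].
mask = create_attention_mask(512, 512, 'on the left', 'no offset', 'a small square area')
print(mask.shape)                                 # torch.Size([262144]) == 512 * 512
print(mask.view(512, 512)[113:398, 0:227].all())  # tensor(True, device='cuda:0')
```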
+
+ # Custom attention processor
+ class CustomCrossAttentionProcessor(AttentionProcessor):
+     def __init__(self, masks, embeddings, adapter_names):
+         super().__init__()
+         self.masks = masks  # list of per-character masks
+         self.embeddings = embeddings  # list of per-character embeddings
+         self.adapter_names = adapter_names  # list of per-character LoRA adapter names
+
+     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
+         # Fetch the current adapter_name
+         adapter_name = getattr(attn, 'adapter_name', None)
+         if adapter_name is None or adapter_name not in self.adapter_names:
+             # Without an adapter_name, fall back to the default attention computation
+             return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
 
+         # Look up the index for this adapter_name
+         idx = self.adapter_names.index(adapter_name)
+         mask = self.masks[idx]
 
+         # Standard attention computation
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         # Reshape for multi-head attention
+         query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+         key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+
+         # Attention scores
+         attention_scores = torch.matmul(query, key.transpose(-1, -2)) * attn.scale
+
+         # Adjust the attention scores with the mask:
+         # reshape the mask to be broadcast-compatible with attention_scores,
+         # assuming key_len matches the mask length
+         mask_expanded = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # (1, 1, 1, key_len)
+         # Apply the mask to attention_scores
+         attention_scores += mask_expanded * 1e6  # boost attention at the masked positions
+
+         # Attention probabilities
+         attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+
+         # Context vectors
+         context = torch.matmul(attention_probs, value)
+
+         # Reshape back to the original layout
+         context = context.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)
+
+         # Output projection
+         hidden_states = attn.to_out(context)
+         return hidden_states
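
The `mask_expanded * 1e6` term is an additive softmax bias: positions that receive the +1e6 boost end up with essentially all of the attention probability. A tiny standalone demonstration (illustrative only):

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([1.0, 0.0, 0.0, 1.0])  # "allow" positions 0 and 3

probs = torch.softmax(scores + mask * 1e6, dim=-1)
print(probs)  # ≈ [0.047, 0.000, 0.000, 0.953]: all mass goes to the boosted positions
```

Note the processor boosts attention toward the masked region rather than suppressing its complement with -inf, which concentrates attention in a softer way.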
+
+ # Swap in the custom attention processors
+ def replace_attention_processors(pipe, masks, embeddings, adapter_names):
+     custom_processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
+     for name, module in pipe.unet.named_modules():
+         if hasattr(module, 'attn2'):
+             # Expose adapter_name as an attribute on the attention module
+             module.attn2.adapter_name = getattr(module, 'adapter_name', None)
+             module.attn2.processor = custom_processor
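
Two caveats on this swap, with an alternative sketch that is not part of the commit: diffusers models expose `set_attn_processor()`, which performs this traversal itself, and a FLUX pipeline holds its denoiser as `pipe.transformer` rather than `pipe.unet`, so the loop above would raise `AttributeError` on FLUX.1-dev as written:

```python
# Alternative sketch using the diffusers processor API; `model` would be
# pipe.transformer for FLUX pipelines (pipe.unet exists only for UNet-based models).
def replace_attention_processors_via_api(model, masks, embeddings, adapter_names):
    processor = CustomCrossAttentionProcessor(masks, embeddings, adapter_names)
    model.set_attn_processor(processor)  # assigns the processor to every attention block
```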
+
+ # Image-generation function
+ @spaces.GPU
+ @torch.inference_mode()
+ def generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress):
+     pipe.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
+
     with calculateDuration("Generating image"):
         # Generate image
+         generated_image = pipe(
+             prompt_embeds=prompt_embeddings,
             num_inference_steps=steps,
             guidance_scale=cfg_scale,
             width=width,
             height=height,
             generator=generator,
         ).images[0]
+
     progress(99, "Generate success!")
+     return generated_image
 
+ # Main function
+ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
+
+     # Parse the character prompts, positions, and LoRA strings
+     try:
+         character_prompts = json.loads(character_prompts_json)
+         character_positions = json.loads(character_positions_json)
+         lora_strings = json.loads(lora_strings_json)
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Invalid JSON input: {e}")
+
+     # The prompts, positions, and LoRA strings must line up one-to-one
+     if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
+         raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")
+
+     # Number of characters
+     num_characters = len(character_prompts)
 
     # Load LoRA weights
+     with calculateDuration("Loading LoRA weights"):
+         pipe.unload_lora_weights()
+         adapter_names = []
+         for lora_info in lora_strings:
+             lora_repo = lora_info.get("repo")
+             weights = lora_info.get("weights")
+             adapter_name = lora_info.get("adapter_name")
+             if lora_repo and weights and adapter_name:
+                 # Load the weights via pipe.load_lora_weights()
+                 pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                 adapter_names.append(adapter_name)
+                 # Record adapter_name as an attribute on the model
+                 setattr(pipe.unet, 'adapter_name', adapter_name)
+             else:
+                 raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
+         adapter_weights = [lora_scale] * len(adapter_names)
+         # Set the adapters and their weights via pipeline.set_adapters
+         pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+
+     # The number of adapters must match the number of characters
+     if len(adapter_names) != num_characters:
+         raise ValueError("The number of LoRA adapters must match the number of characters.")
+
     # Set random seed for reproducibility
     if randomize_seed:
+         with calculateDuration("Set random seed"):
+             seed = random.randint(0, MAX_SEED)
 
+     # Encode the prompts
+     with calculateDuration("Encoding prompts"):
+         # Encode the background prompt
+         bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to("cuda")
+         bg_embeddings = pipe.text_encoder(bg_text_input.input_ids.to(device))[0]
+
+         # Encode the character prompts
+         character_embeddings = []
+         for prompt in character_prompts:
+             char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to("cuda")
+             char_embeddings = pipe.text_encoder(char_text_input.input_ids.to(device))[0]
+             character_embeddings.append(char_embeddings)
+
+         # Encode the interaction-details prompt
+         details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to("cuda")
+         details_embeddings = pipe.text_encoder(details_text_input.input_ids.to(device))[0]
+
+         # Concatenate the background and interaction-detail embeddings
+         prompt_embeddings = torch.cat([bg_embeddings, details_embeddings], dim=1)
+
+     # Parse the character positions
+     character_infos = []
+     for position_str in character_positions:
+         info = parse_character_position(position_str)
+         character_infos.append(info)
+
+     # Build one mask per character
+     masks = []
+     for info in character_infos:
+         mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
+         masks.append(mask)
+
+     # Replace the attention processors
+     replace_attention_processors(pipe, masks, character_embeddings, adapter_names)
+
+     # Generate image
+     final_image = generate_image_with_embeddings(prompt_embeddings, steps, seed, cfg_scale, width, height, progress)
+
+     # The image-upload step could be added here
+     result = {"status": "success", "message": "Image generated"}
 
+     progress(100, "Completed!")
 
+     return final_image, seed, json.dumps(result)
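
The three JSON text fields are decoded with `json.loads` and must have equal lengths. Hypothetical example values (the repo and file names are placeholders, not real artifacts):

```python
character_prompts_json = '["a boy in a red jacket", "a girl holding an umbrella"]'
character_positions_json = (
    '["on the left, a medium-sized vertical area", '
    '"on the right, a medium-sized vertical area"]'
)
lora_strings_json = (
    '[{"repo": "user/boy-lora", "weights": "boy.safetensors", "adapter_name": "boy"}, '
    '{"repo": "user/girl-lora", "weights": "girl.safetensors", "adapter_name": "girl"}]'
)
```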
 
+ # Gradio UI
 css="""
 #col-container {
     margin: 0 auto;
 
 """
 
 with gr.Blocks(css=css) as demo:
+     gr.Markdown("Flux with LoRA")
     with gr.Row():
 
         with gr.Column():
+
+             prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
+             character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
+             character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
+             lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
+             prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
             run_button = gr.Button("Run", scale=0)
 
         with gr.Accordion("Advanced Settings", open=False):
 
             lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5)
 
             with gr.Row():
+                 width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512)
+                 height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512)
 
             with gr.Row():
+                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
                 steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
 
             upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False)
 
         with gr.Column():
             result = gr.Image(label="Result", show_label=False)
+             seed_output = gr.Text(label="Seed")
+             json_text = gr.Text(label="Result JSON")
+
+     inputs = [
+         prompt_bg,
+         character_prompts,
+         character_positions,
+         lora_strings_json,
+         prompt_details,
+         cfg_scale,
+         steps,
+         randomize_seed,
+         seed,
+         width,
+         height,
+         lora_scale,
+         upload_to_r2,
+         account_id,
+         access_key,
+         secret_key,
+         bucket
+     ]
+
+     outputs = [result, seed_output, json_text]
+
+     run_button.click(
+         fn=run_lora,
+         inputs=inputs,
+         outputs=outputs
     )
 
+ demo.queue().launch()
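
One loose end worth flagging: the new file keeps the `upload_to_r2` checkbox and the R2 credential fields (and still decorates generation with `@spaces.GPU` even though `import spaces` was removed), but `run_lora` no longer uploads anything. Restoring that behaviour would mean re-adding `boto3`, `import spaces`, and the removed `upload_image_to_r2` helper, then re-wiring the branch inside `run_lora`, roughly as in the previous revision:

```python
# Sketch, assuming upload_image_to_r2 from the previous revision is restored:
if upload_to_r2:
    with calculateDuration("upload r2"):
        url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
    result = {"status": "success", "url": url}
else:
    result = {"status": "success", "message": "Image generated but not uploaded"}
```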