jiuface commited on
Commit
a9da525
1 Parent(s): f289e28

support multi loras

Browse files
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
app.py CHANGED
@@ -18,6 +18,9 @@ import boto3
18
  from io import BytesIO
19
  from datetime import datetime
20
 
 
 
 
21
 
22
  HF_TOKEN = os.environ.get("HF_TOKEN")
23
 
@@ -27,7 +30,20 @@ login(token=HF_TOKEN)
27
  dtype = torch.bfloat16
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  base_model = "black-forest-labs/FLUX.1-dev"
 
 
 
 
 
 
 
 
 
 
30
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
 
 
 
31
  MAX_SEED = 2**32-1
32
 
33
  class calculateDuration:
@@ -70,7 +86,7 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
70
 
71
 
72
  @spaces.GPU
73
- def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, progress):
74
  pipe.to("cuda")
75
  generator = torch.Generator(device="cuda").manual_seed(seed)
76
  with calculateDuration("Generating image"):
@@ -82,7 +98,7 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, pr
82
  width=width,
83
  height=height,
84
  generator=generator,
85
- joint_attention_kwargs={"scale": lora_scale},
86
  max_sequence_length=256
87
  ).images[0]
88
 
@@ -90,21 +106,34 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, pr
90
  return generate_image
91
 
92
 
93
- def run_lora(prompt, cfg_scale, steps, lora_repo, lora_name, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
94
 
95
- # with calculateDuration("Unloading LoRA"):
96
- # pipe.unload_lora_weights()
97
-
98
- # Load LoRA weights
99
- if lora_repo and lora_name:
100
- with calculateDuration(f"Loading LoRA weights for {lora_repo} {lora_name}"):
101
- pipe.load_lora_weights(lora_repo, weight_name=lora_name)
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # Set random seed for reproducibility
104
  if randomize_seed:
105
  seed = random.randint(0, MAX_SEED)
106
 
107
- final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale, progress)
108
 
109
  if upload_to_r2:
110
  with calculateDuration("upload r2"):
@@ -131,8 +160,7 @@ with gr.Blocks(css=css) as demo:
131
 
132
  with gr.Column():
133
  prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
134
- lora_repo = gr.Text( label="Repo", max_lines=1, placeholder="Enter a lora repo", visible=True)
135
- lora_name = gr.Text( label="Weights", max_lines=1, placeholder="Enter a lora weights",visible=True)
136
  run_button = gr.Button("Run", scale=0)
137
 
138
  with gr.Accordion("Advanced Settings", open=False):
@@ -164,7 +192,7 @@ with gr.Blocks(css=css) as demo:
164
  gr.on(
165
  triggers=[run_button.click, prompt.submit],
166
  fn = run_lora,
167
- inputs = [prompt, cfg_scale, steps, lora_repo, lora_name, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket],
168
  outputs=[result, seed, json_text]
169
  )
170
 
 
18
  from io import BytesIO
19
  from datetime import datetime
20
 
21
+ from diffusers import UNet2DConditionModel
22
+
23
+
24
 
25
  HF_TOKEN = os.environ.get("HF_TOKEN")
26
 
 
30
  dtype = torch.bfloat16
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  base_model = "black-forest-labs/FLUX.1-dev"
33
+
34
+ # unet = UNet2DConditionModel.from_pretrained(
35
+ # base_model,
36
+ # torch_dtype=torch.float16,
37
+ # use_safetensors=True,
38
+ # variant="fp16",
39
+ # subfolder="unet",
40
+ # ).to("cuda")
41
+
42
+
43
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
44
+
45
+
46
+
47
  MAX_SEED = 2**32-1
48
 
49
  class calculateDuration:
 
86
 
87
 
88
  @spaces.GPU
89
+ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
90
  pipe.to("cuda")
91
  generator = torch.Generator(device="cuda").manual_seed(seed)
92
  with calculateDuration("Generating image"):
 
98
  width=width,
99
  height=height,
100
  generator=generator,
101
+ cross_attention_kwargs={"scale": 1.0},
102
  max_sequence_length=256
103
  ).images[0]
104
 
 
106
  return generate_image
107
 
108
 
109
+ def run_lora(prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
110
 
 
 
 
 
 
 
 
111
 
112
+ # Load LoRA weights
113
+ if lora_strings:
114
+ with calculateDuration(f"Loading LoRA weights for {lora_strings}"):
115
+ pipe.unload_lora_weights()
116
+ lora_array = lora_strings.split(',')
117
+ adapter_names = []
118
+ for lora_string in lora_array:
119
+ parts = lora_string.split(':')
120
+ if len(parts) == 3:
121
+ lora_repo, weights, adapter_name = parts
122
+ # 调用 pipe.load_lora_weights() 方法加载权重
123
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
124
+ adapter_names.append(adapter_name)
125
+ else:
126
+ print(f"Invalid format for lora_string: {lora_string}")
127
+
128
+ adapter_weights = [lora_scale] * len(adapter_names)
129
+ # 调用 pipeline.set_adapters 方法设置 adapter 和对应权重
130
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
131
+
132
  # Set random seed for reproducibility
133
  if randomize_seed:
134
  seed = random.randint(0, MAX_SEED)
135
 
136
+ final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
137
 
138
  if upload_to_r2:
139
  with calculateDuration("upload r2"):
 
160
 
161
  with gr.Column():
162
  prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
163
+ lora_strings = gr.Text( label="lora_strings", max_lines=1, placeholder="Enter a lora strings", visible=True)
 
164
  run_button = gr.Button("Run", scale=0)
165
 
166
  with gr.Accordion("Advanced Settings", open=False):
 
192
  gr.on(
193
  triggers=[run_button.click, prompt.submit],
194
  fn = run_lora,
195
+ inputs = [prompt, cfg_scale, steps, lora_strings, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket],
196
  outputs=[result, seed, json_text]
197
  )
198