jiuface committed
Commit b9ea7a6
1 Parent(s): cdeb4dc

Update app.py

Files changed (1):
  app.py (+61 -19)
app.py CHANGED
@@ -6,8 +6,10 @@ import spaces
 import torch
 import json
 import logging
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
 from huggingface_hub import login
+from diffusers.utils import load_image
+
 import time
 from datetime import datetime
 from io import BytesIO
@@ -23,7 +25,6 @@ import json
 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
 import diffusers
-print(diffusers.__version__)
 
 # init
 dtype = torch.float16 # use float16 for fast generate
@@ -31,8 +32,23 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 base_model = "black-forest-labs/FLUX.1-dev"
 
 # load pipe
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
 
+# img2img model
+img2img = AutoPipelineForImage2Image.from_pretrained(base_model,
+                                                     vae=good_vae,
+                                                     transformer=pipe.transformer,
+                                                     text_encoder=pipe.text_encoder,
+                                                     tokenizer=pipe.tokenizer,
+                                                     text_encoder_2=pipe.text_encoder_2,
+                                                     tokenizer_2=pipe.tokenizer_2,
+                                                     torch_dtype=dtype
+                                                     )
+
+
+
 MAX_SEED = 2**32 - 1
 
 class calculateDuration:
@@ -56,8 +72,7 @@ class calculateDuration:
         print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
 @spaces.GPU(duration=120)
-@torch.inference_mode()
-def generate_image(prompt, adapter_names, steps, seed, cfg_scale, width, height, progress):
+def generate_image(orginal_image, prompt, adapter_names, steps, seed, image_strength, cfg_scale, width, height, progress):
 
 
     gr.Info("Start to generate images ...")
@@ -67,15 +82,28 @@ def generate_image(prompt, adapter_names, steps, seed, cfg_scale, width, height
 
     with calculateDuration("Generating image"):
         # Generate image
-        generated_image = pipe(
-            prompt=prompt,
-            num_inference_steps=steps,
-            guidance_scale=cfg_scale,
-            width=width,
-            height=height,
-            max_sequence_length=512,
-            generator=generator,
-        ).images[0]
+        if orginal_image:
+            generated_image = img2img(
+                prompt=prompt,
+                image=orginal_image,
+                strength=image_strength,
+                num_inference_steps=steps,
+                guidance_scale=cfg_scale,
+                width=width,
+                height=height,
+                generator=generator,
+                joint_attention_kwargs={"scale": lora_scale}
+            ).images[0]
+        else:
+            generated_image = pipe(
+                prompt=prompt,
+                num_inference_steps=steps,
+                guidance_scale=cfg_scale,
+                width=width,
+                height=height,
+                max_sequence_length=512,
+                generator=generator,
+            ).images[0]
 
     progress(99, "Generate image success!")
     return generated_image
@@ -119,10 +147,15 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
     print("upload thumbnail finish", thumbnail_file)
     return image_file
 
-def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
+def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, steps, randomize_seed, seed, width, height, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
     print("run_lora", prompt, lora_strings_json, cfg_scale, steps, width, height)
     gr.Info("Starting process")
-
+
+    img2img_model = False
+    orginal_image = None
+    if image_url:
+        orginal_image = load_image(image_url)
+        img2img_model = True
     # Set random seed for reproducibility
     if randomize_seed:
         with calculateDuration("Set random seed"):
@@ -152,7 +185,10 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
             retry_count = 3
             for attempt in range(retry_count):
                 try:
-                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    if img2img_model:
+                        img2img.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    else:
+                        pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                     adapter_names.append(adapter_name)
                     adapter_weights.append(adapter_weight)
                     break # Load successful, exit retry loop
@@ -165,14 +201,17 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
 
     # set lora weights
     if len(adapter_names) > 0:
-        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+        if img2img_model:
+            img2img.set_adapters(adapter_names, adapter_weights=adapter_weights)
+        else:
+            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
 
 
     # Generate image
     error_message = ""
     try:
         print("Start applying for zeroGPU resources")
-        final_image = generate_image(prompt, adapter_names, steps, seed, cfg_scale, width, height, progress)
+        final_image = generate_image(orginal_image, prompt, adapter_names, steps, seed, image_strength, cfg_scale, width, height, progress)
     except Exception as e:
         error_message = str(e)
         gr.Error(error_message)
@@ -210,7 +249,7 @@ with gr.Blocks(css=css) as demo:
 
             prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
             lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
-
+            image_url = gr.Text(label="Image url", placeholder="Enter image url to enable image to image model", lines=1)
            run_button = gr.Button("Run", scale=0)
 
         with gr.Accordion("Advanced Settings", open=False):
@@ -224,6 +263,7 @@ with gr.Blocks(css=css) as demo:
                 height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
 
             with gr.Row():
+                image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                 steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
 
@@ -244,7 +284,9 @@ with gr.Blocks(css=css) as demo:
     )
     inputs = [
         prompt,
+        image_url,
         lora_strings_json,
+        image_strength,
         cfg_scale,
         steps,
         randomize_seed,
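
For readers who want to try the image-to-image path added in this commit outside of the Gradio UI, here is a minimal sketch. It mirrors the pipelines and slider defaults above (FLUX.1-dev, strength 0.75, 28 steps, CFG 3.5); the prompt, input image URL, and LoRA repository/weight names are placeholder assumptions, and the joint_attention_kwargs argument from the commit is left out of the sketch.

import torch
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image
from diffusers.utils import load_image

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

# Text-to-image pipeline; the img2img pipeline reuses its transformer and text encoders,
# matching how the commit builds `img2img` from `pipe`.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
img2img = AutoPipelineForImage2Image.from_pretrained(
    base_model,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
).to(device)

# Placeholder LoRA entry in the same shape the "LoRA Configs" textbox expects.
lora = {"repo": "your-user/your-flux-lora", "weights": "lora.safetensors",
        "adapter_name": "my_style", "adapter_weight": 1.0}
img2img.load_lora_weights(lora["repo"], weight_name=lora["weights"],
                          adapter_name=lora["adapter_name"])
img2img.set_adapters([lora["adapter_name"]], adapter_weights=[lora["adapter_weight"]])

# Placeholder inputs; the numeric values are the UI slider defaults.
init_image = load_image("https://example.com/input.png")
generator = torch.Generator(device=device).manual_seed(42)

result = img2img(
    prompt="a watercolor painting of a lighthouse",  # placeholder prompt
    image=init_image,
    strength=0.75,           # "Denoise Strength" default
    num_inference_steps=28,  # "Steps" default
    guidance_scale=3.5,      # "CFG Scale" default
    width=1024,
    height=1024,
    generator=generator,
).images[0]
result.save("output.png")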