multimodalart HF staff commited on
Commit
07d0ab0
1 Parent(s): 842f2ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -18
app.py CHANGED
@@ -39,6 +39,10 @@ MAX_IMAGES = 150
39
 
40
 
41
  def load_captioning(uploaded_images, concept_sentence):
 
 
 
 
42
  gr.Info("Images uploaded!")
43
  updates = []
44
  if len(uploaded_images) <= 1:
@@ -54,17 +58,23 @@ def load_captioning(uploaded_images, concept_sentence):
54
  for i in range(1, MAX_IMAGES + 1):
55
  # Determine if the current row and image should be visible
56
  visible = i <= len(uploaded_images)
57
-
58
  # Update visibility of the captioning row
59
  updates.append(gr.update(visible=visible))
60
 
61
  # Update for image component - display image if available, otherwise hide
62
  image_value = uploaded_images[i - 1] if visible else None
63
-
64
  updates.append(gr.update(value=image_value, visible=visible))
65
 
 
 
 
 
 
 
 
66
  # Update value of captioning area
67
- text_value = "[trigger]" if visible and concept_sentence else None
68
  updates.append(gr.update(value=text_value, visible=visible))
69
 
70
  # Update for the sample caption area
@@ -145,6 +155,8 @@ def start_training(
145
  sample_1,
146
  sample_2,
147
  sample_3,
 
 
148
  profile: Union[gr.OAuthProfile, None],
149
  oauth_token: Union[gr.OAuthToken, None],
150
  ):
@@ -197,6 +209,10 @@ def start_training(
197
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
198
  else:
199
  config["config"]["process"][0]["train"]["disable_sampling"] = True
 
 
 
 
200
  # Save the updated config
201
  # generate a random name for the config
202
  random_config_name = str(uuid.uuid4())
@@ -232,20 +248,6 @@ def start_training(
232
 
233
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
234
 
235
-
236
- theme = gr.themes.Monochrome(
237
- text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
238
- font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
239
- )
240
- css = """
241
- h1{font-size: 2em}
242
- h3{margin-top: 0}
243
- #component-1{text-align:center}
244
- .main_ui_logged_out{opacity: 0.3; pointer-events: none}
245
- .tabitem{border: 0px}
246
- .group_padding{padding: .55em}
247
- """
248
-
249
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
250
  if is_spaces:
251
  if profile is None:
@@ -272,6 +274,64 @@ def update_pricing(steps, oauth_token: Union[gr.OAuthToken, None]):
272
  else:
273
  return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  with gr.Blocks(theme=theme, css=css) as demo:
276
  gr.Markdown(
277
  """# LoRA Ease for FLUX 🧞‍♂️
@@ -295,8 +355,9 @@ with gr.Blocks(theme=theme, css=css) as demo:
295
  )
296
  with gr.Group(visible=True) as image_upload:
297
  with gr.Row():
 
298
  images = gr.File(
299
- file_types=["image"],
300
  label="Upload your images",
301
  file_count="multiple",
302
  interactive=True,
@@ -339,6 +400,11 @@ with gr.Blocks(theme=theme, css=css) as demo:
339
  steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
340
  lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
341
  rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
 
 
 
 
 
342
 
343
  with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
344
  gr.Markdown(
@@ -424,6 +490,8 @@ with gr.Blocks(theme=theme, css=css) as demo:
424
  sample_1,
425
  sample_2,
426
  sample_3,
 
 
427
  ],
428
  outputs=progress_area,
429
  )
 
39
 
40
 
41
  def load_captioning(uploaded_images, concept_sentence):
42
+
43
+ uploaded_images = [file for file in uploaded_images if not file.endswith('.txt')]
44
+ txt_files = [file for file in uploaded_images if file.endswith('.txt')]
45
+
46
  gr.Info("Images uploaded!")
47
  updates = []
48
  if len(uploaded_images) <= 1:
 
58
  for i in range(1, MAX_IMAGES + 1):
59
  # Determine if the current row and image should be visible
60
  visible = i <= len(uploaded_images)
61
+
62
  # Update visibility of the captioning row
63
  updates.append(gr.update(visible=visible))
64
 
65
  # Update for image component - display image if available, otherwise hide
66
  image_value = uploaded_images[i - 1] if visible else None
 
67
  updates.append(gr.update(value=image_value, visible=visible))
68
 
69
+ base_name = image_value.rsplit('.', 1)[0]
70
+
71
+ corresponding_txt = base_name + '.txt'
72
+ corresponding_caption = False
73
+ if corresponding_txt in txt_files:
74
+ with open(corresponding_txt, 'r') as file:
75
+ corresponding_caption = file.read()
76
  # Update value of captioning area
77
+ text_value = corresponding_caption if corresponding_caption else "[trigger]" if visible and concept_sentence else None
78
  updates.append(gr.update(value=text_value, visible=visible))
79
 
80
  # Update for the sample caption area
 
155
  sample_1,
156
  sample_2,
157
  sample_3,
158
+ use_more_advanced_options,
159
+ more_advanced_options,
160
  profile: Union[gr.OAuthProfile, None],
161
  oauth_token: Union[gr.OAuthToken, None],
162
  ):
 
209
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
210
  else:
211
  config["config"]["process"][0]["train"]["disable_sampling"] = True
212
+
213
+ if(use_more_advanced_options):
214
+ config["config"]["process"] = more_advanced_options
215
+
216
  # Save the updated config
217
  # generate a random name for the config
218
  random_config_name = str(uuid.uuid4())
 
248
 
249
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
252
  if is_spaces:
253
  if profile is None:
 
274
  else:
275
  return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
276
 
277
+ config_yaml = {
278
+ "device": "cuda:0",
279
+ "network": {
280
+ "type": "lora",
281
+ "linear": 16,
282
+ "linear_alpha": 16
283
+ },
284
+ "save": {
285
+ "dtype": "float16",
286
+ "save_every": 10000,
287
+ "max_step_saves_to_keep": 4,
288
+ "push_to_hub": True,
289
+ "hf_private": True
290
+ },
291
+ "train": {
292
+ "batch_size": 1,
293
+ "gradient_accumulation_steps": 1,
294
+ "train_unet": True,
295
+ "train_text_encoder": False,
296
+ "gradient_checkpointing": True,
297
+ "noise_scheduler": "flowmatch",
298
+ "optimizer": "adamw8bit",
299
+ "ema_config": {
300
+ "use_ema": True,
301
+ "ema_decay": 0.99
302
+ },
303
+ "dtype": "bf16"
304
+ },
305
+ "model": {
306
+ "name_or_path": "black-forest-labs/FLUX.1-dev",
307
+ "is_flux": True,
308
+ "quantize": True
309
+ },
310
+ "sample": {
311
+ "sampler": "flowmatch",
312
+ "sample_every": 1000,
313
+ "width": 1024,
314
+ "height": 1024,
315
+ "neg": "",
316
+ "seed": 42,
317
+ "walk_seed": True,
318
+ "guidance_scale": 3.5,
319
+ "sample_steps": 28
320
+ }
321
+ }
322
+
323
+ theme = gr.themes.Monochrome(
324
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
325
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
326
+ )
327
+ css = """
328
+ h1{font-size: 2em}
329
+ h3{margin-top: 0}
330
+ #component-1{text-align:center}
331
+ .main_ui_logged_out{opacity: 0.3; pointer-events: none}
332
+ .tabitem{border: 0px}
333
+ .group_padding{padding: .55em}
334
+ """
335
  with gr.Blocks(theme=theme, css=css) as demo:
336
  gr.Markdown(
337
  """# LoRA Ease for FLUX 🧞‍♂️
 
355
  )
356
  with gr.Group(visible=True) as image_upload:
357
  with gr.Row():
358
+ gr.Markdown("Upload your images to caption them in the UI (if you already have a dataset with .txt captions, upload them together)")
359
  images = gr.File(
360
+ file_types=["image", ".txt"],
361
  label="Upload your images",
362
  file_count="multiple",
363
  interactive=True,
 
400
  steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
401
  lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
402
  rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
403
+ with gr.Accordion("Even more advanced options", open=False):
404
+ if(is_spaces):
405
+ gr.Markdown("Attention: changing this parameters may make your training fail or go out-of-memory if training on Spaces. Only change settings here it if you know what you are doing. Beware that training is done in an L4 GPU with 24GB of RAM")
406
+ use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False)
407
+ more_advanced_options = gr.Code(config_yaml, language="yaml")
408
 
409
  with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
410
  gr.Markdown(
 
490
  sample_1,
491
  sample_2,
492
  sample_3,
493
+ use_more_advanced_options,
494
+ more_advanced_options
495
  ],
496
  outputs=progress_area,
497
  )