multimodalart HF staff commited on
Commit
c10e71a
1 Parent(s): 867b96b

class images as zip

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -16,6 +16,8 @@ import importlib
16
  import sys
17
  from pathlib import Path
18
  import spaces
 
 
19
  MAX_IMAGES = 50
20
 
21
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/add-peft-to-advanced-training-script/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
@@ -288,12 +290,16 @@ def start_training(
288
  commands.append(f"num_class_images={int(num_class_images)}")
289
  if class_images:
290
  class_folder = str(uuid.uuid4())
291
- if not os.path.exists(class_folder):
292
- os.makedirs(class_folder)
293
- for image in class_images:
294
- shutil.copy(image, class_folder)
 
 
 
 
 
295
  commands.append(f"class_data_dir={class_folder}")
296
- shutil.copytree(class_folder, f"{spacerunner_folder}/{class_folder}")
297
  if use_prodigy_beta3:
298
  commands.append(f"prodigy_beta3={prodigy_beta3}")
299
  if use_adam_weight_decay_text_encoder:
 
16
  import sys
17
  from pathlib import Path
18
  import spaces
19
+ import zipfile
20
+
21
  MAX_IMAGES = 50
22
 
23
  training_script_url = "https://raw.githubusercontent.com/huggingface/diffusers/add-peft-to-advanced-training-script/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py"
 
290
  commands.append(f"num_class_images={int(num_class_images)}")
291
  if class_images:
292
  class_folder = str(uuid.uuid4())
293
+ zip_path = os.path.join(spacerunner_folder, class_folder, "class_images.zip")
294
+
295
+ if not os.path.exists(os.path.join(spacerunner_folder, class_folder)):
296
+ os.makedirs(os.path.join(spacerunner_folder, class_folder))
297
+
298
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
299
+ for image in class_images:
300
+ zipf.write(image, os.path.basename(image))
301
+
302
  commands.append(f"class_data_dir={class_folder}")
 
303
  if use_prodigy_beta3:
304
  commands.append(f"prodigy_beta3={prodigy_beta3}")
305
  if use_adam_weight_decay_text_encoder: