fffiloni commited on
Commit eddc0c5
1 Parent(s): f57f3d1

Update app.py

Files changed (1)
  1. app.py +378 -388
app.py CHANGED
@@ -1,419 +1,409 @@
- import json
- import logging
- import os
-
- import blobfile as bf
  import torch
  import gc
- from datasets import load_dataset
- from pytorch_lightning import seed_everything
- from tqdm import tqdm
-
  from arguments import parse_args
- from models import get_model, get_multi_apply_fn
- from rewards import get_reward_losses
- from training import LatentNoiseTrainer, get_optimizer
-

- import torch
- import gc

- def clear_gpu():
-     """Clear GPU memory by removing tensors, freeing cache, and moving data to CPU."""
-     # List memory usage before clearing
-     print(f"Memory allocated before clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
-     print(f"Memory reserved before clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
-
-     # Force the garbage collector to free unreferenced objects
-     gc.collect()
-
-     # Move any bound tensors back to CPU if needed
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()  # Free up the cached memory
-         torch.cuda.ipc_collect()  # Clear any cross-process memory
-
-     print(f"Memory allocated after clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
-     print(f"Memory reserved after clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
-
- def unload_previous_model_if_needed(loaded_model_setup):
-     """Unload the current model from the GPU and free resources if a new model is being loaded."""
-     if loaded_model_setup is not None:
-         print("Unloading previous model from GPU to free memory.")
-         previous_model = loaded_model_setup[7]  # Assuming pipe is at position [7] in the setup
-         if hasattr(previous_model, 'to') and loaded_model_setup[0].model != "flux":
-             previous_model.to('cpu')  # Move model to CPU to free GPU memory
-         del previous_model  # Delete the reference to the model
-         clear_gpu()  # Clear all remaining GPU memory
-
- def setup(args, loaded_model_setup=None):
-     seed_everything(args.seed)
-     bf.makedirs(f"{args.save_dir}/logs/{args.task}")
-
-     # Set up logging and name settings
-     logger = logging.getLogger()
-     logger.handlers.clear()  # Clear existing handlers
-     settings = (
-         f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
-         f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
-         f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
-         f"_reg{args.reg_weight if args.enable_reg else '0'}"
-         f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
-         f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
-         f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
-         f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
-         f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
-     )
-
-     file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
-     handler = logging.StreamHandler(file_stream)
-     formatter = logging.Formatter("%(asctime)s - %(message)s")
-     handler.setFormatter(formatter)
-     logger.addHandler(handler)
-     logger.setLevel("INFO")
-     consoleHandler = logging.StreamHandler()
-     consoleHandler.setFormatter(formatter)
-     logger.addHandler(consoleHandler)
-
-     logging.info(args)
-
-     if args.device_id is not None:
-         logging.info(f"Using CUDA device {args.device_id}")
-         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-         os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
-
-     device = torch.device("cuda")
-     dtype = torch.float16 if args.dtype == "float16" else torch.float32
-
-     # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
-     if loaded_model_setup and args.model == loaded_model_setup[0].model:
-         print(f"Reusing model {args.model} from loaded setup.")
-         trainer = loaded_model_setup[1]  # Trainer is at position 1 in loaded_model_setup
-
-         # Update trainer with the new arguments
-         trainer.n_iters = args.n_iters
-         trainer.n_inference_steps = args.n_inference_steps
-         trainer.seed = args.seed
-         trainer.save_all_images = args.save_all_images
-         trainer.no_optim = args.no_optim
-         trainer.regularize = args.enable_reg
-         trainer.regularization_weight = args.reg_weight
-         trainer.grad_clip = args.grad_clip
-         trainer.log_metrics = args.task == "single" or not args.no_optim
-         trainer.imageselect = args.imageselect
-
-         # Get latents (this step is still required)
-         if args.model == "flux":
-             shape = (1, 16 * 64, 64)
-         elif args.model != "pixart":
-             height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
-             width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
-             shape = (
-                 1,
-                 trainer.model.unet.in_channels,
-                 height // trainer.model.vae_scale_factor,
-                 width // trainer.model.vae_scale_factor,
-             )
-         else:
-             height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
-             width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
-             shape = (
-                 1,
-                 trainer.model.transformer.config.in_channels,
-                 height // trainer.model.vae_scale_factor,
-                 width // trainer.model.vae_scale_factor,
-             )
-
-         pipe = loaded_model_setup[7]
-         enable_grad = not args.no_optim
-
-         return args, trainer, device, dtype, shape, enable_grad, settings, pipe
-
-     # Unload previous model and clear GPU resources
-     unload_previous_model_if_needed(loaded_model_setup)
-
-     # Proceed with full model loading if args.model is different
-     print(f"Loading new model: {args.model}")
-
-     # Get reward losses
-     reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
-
-     # Get model and noise trainer
-     pipe = get_model(
-         args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
-     )
-
-     # Final memory cleanup after model loading
-     torch.cuda.empty_cache()
-     gc.collect()
-
-     trainer = LatentNoiseTrainer(
-         reward_losses=reward_losses,
-         model=pipe,
-         n_iters=args.n_iters,
-         n_inference_steps=args.n_inference_steps,
-         seed=args.seed,
-         save_all_images=args.save_all_images,
-         device=device if not args.cpu_offloading else 'cpu',  # Use CPU if offloading is enabled
-         no_optim=args.no_optim,
-         regularize=args.enable_reg,
-         regularization_weight=args.reg_weight,
-         grad_clip=args.grad_clip,
-         log_metrics=args.task == "single" or not args.no_optim,
-         imageselect=args.imageselect,
-     )
-
-     # Create latents
-     if args.model == "flux":
-         shape = (1, 16 * 64, 64)
-     elif args.model != "pixart":
-         height = pipe.unet.config.sample_size * pipe.vae_scale_factor
-         width = pipe.unet.config.sample_size * pipe.vae_scale_factor
-         shape = (
-             1,
-             pipe.unet.in_channels,
-             height // pipe.vae_scale_factor,
-             width // pipe.vae_scale_factor,
-         )
-     else:
-         height = pipe.transformer.config.sample_size * pipe.vae_scale_factor
-         width = pipe.transformer.config.sample_size * pipe.vae_scale_factor
-         shape = (
-             1,
-             pipe.transformer.config.in_channels,
-             height // pipe.vae_scale_factor,
-             width // pipe.vae_scale_factor,
-         )
-
-     enable_grad = not args.no_optim
-
-     # Final memory cleanup
      torch.cuda.empty_cache()  # Free up cached memory
      gc.collect()
-
-     return args, trainer, device, dtype, shape, enable_grad, settings, pipe
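
The 8-tuple returned here is a positional contract the rest of the commit leans on; a minimal sketch of a caller unpacking it (caller code assumed, names taken from the diff):

    result = setup(args, loaded_model_setup=None)
    args, trainer, device, dtype, shape, enable_grad, settings, pipe = result
    # unload_previous_model_if_needed() above likewise assumes trainer at index 1 and pipe at index 7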
- def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback=None):
-
-     if args.task == "single":
-         # Attempt to move the model to GPU if model is not Flux
-         if args.model != "flux":
-             if pipe.device != torch.device('cuda'):
-                 pipe.to(device, dtype)
-         else:
-             print(f"PIPE:{pipe}")
-
-         if args.cpu_offloading:
-             pipe.enable_sequential_cpu_offload()
-
-         #if pipe.device != torch.device('cuda'):
-         #    pipe.to(device, dtype)
-
-         if args.enable_multi_apply:
-             multi_apply_fn = get_multi_apply_fn(
-                 model_type=args.multi_step_model,
-                 seed=args.seed,
-                 pipe=pipe,
-                 cache_dir=args.cache_dir,
-                 device=device if not args.cpu_offloading else 'cpu',
-                 dtype=dtype,
-             )
-         else:
-             multi_apply_fn = None
-
          torch.cuda.empty_cache()  # Free up cached memory
          gc.collect()
-
-         init_latents = torch.randn(shape, device=device, dtype=dtype)
-         latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
-         optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
-         save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
-         os.makedirs(f"{save_dir}", exist_ok=True)
-         init_image, best_image, total_init_rewards, total_best_rewards = trainer.train(
-             latents, args.prompt, optimizer, save_dir, multi_apply_fn, progress_callback=progress_callback
-         )
-         best_image.save(f"{save_dir}/best_image.png")
-         #init_image.save(f"{save_dir}/init_image.png")
-
-     elif args.task == "example-prompts":
-         fo = open("assets/example_prompts.txt", "r")
-         prompts = fo.readlines()
-         fo.close()
-         for i, prompt in tqdm(enumerate(prompts)):
-             # Get new latents and optimizer
-             init_latents = torch.randn(shape, device=device, dtype=dtype)
-             latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
-             optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
-
-             prompt = prompt.strip()
-             name = f"{i:03d}_{prompt[:150]}.png"
-             save_dir = f"{args.save_dir}/{args.task}/{settings}/{name}"
-             os.makedirs(save_dir, exist_ok=True)
-             init_image, best_image, init_rewards, best_rewards = trainer.train(
-                 latents, prompt, optimizer, save_dir, multi_apply_fn
-             )
-             if i == 0:
-                 total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
-                 total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
-             for k in best_rewards.keys():
-                 total_best_rewards[k] += best_rewards[k]
-                 total_init_rewards[k] += init_rewards[k]
-             best_image.save(f"{save_dir}/best_image.png")
-             init_image.save(f"{save_dir}/init_image.png")
-             logging.info(f"Initial rewards: {init_rewards}")
-             logging.info(f"Best rewards: {best_rewards}")
-         for k in total_best_rewards.keys():
-             total_best_rewards[k] /= len(prompts)
-             total_init_rewards[k] /= len(prompts)
-
-         # save results to directory
-         with open(f"{args.save_dir}/example-prompts/{settings}/results.txt", "w") as f:
-             f.write(
-                 f"Mean initial all rewards: {total_init_rewards}\n"
-                 f"Mean best all rewards: {total_best_rewards}\n"
-             )
-     elif args.task == "t2i-compbench":
-         prompt_list_file = f"../T2I-CompBench/examples/dataset/{args.prompt}.txt"
-         fo = open(prompt_list_file, "r")
-         prompts = fo.readlines()
-         fo.close()
-         os.makedirs(f"{args.save_dir}/{args.task}/{settings}/samples", exist_ok=True)
-         for i, prompt in tqdm(enumerate(prompts)):
-             # Get new latents and optimizer
-             init_latents = torch.randn(shape, device=device, dtype=dtype)
-             latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
-             optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
-
-             prompt = prompt.strip()
-             init_image, best_image, init_rewards, best_rewards = trainer.train(
-                 latents, prompt, optimizer, None, multi_apply_fn
-             )
-             if i == 0:
-                 total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
-                 total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
-             for k in best_rewards.keys():
-                 total_best_rewards[k] += best_rewards[k]
-                 total_init_rewards[k] += init_rewards[k]
-             name = f"{prompt}_{i:06d}.png"
-             best_image.save(f"{args.save_dir}/{args.task}/{settings}/samples/{name}")
-             logging.info(f"Initial rewards: {init_rewards}")
-             logging.info(f"Best rewards: {best_rewards}")
-         for k in total_best_rewards.keys():
-             total_best_rewards[k] /= len(prompts)
-             total_init_rewards[k] /= len(prompts)
-     elif args.task == "parti-prompts":
-         parti_dataset = load_dataset("nateraw/parti-prompts", split="train")
-         total_reward_diff = 0.0
-         total_best_reward = 0.0
-         total_init_reward = 0.0
-         total_improved_samples = 0
-         for index, sample in enumerate(parti_dataset):
-             os.makedirs(
-                 f"{args.save_dir}/{args.task}/{settings}/{index}", exist_ok=True
-             )
-             prompt = sample["Prompt"]
-             init_image, best_image, init_rewards, best_rewards = trainer.train(
-                 latents, prompt, optimizer, multi_apply_fn
-             )
-             best_image.save(
-                 f"{args.save_dir}/{args.task}/{settings}/{index}/best_image.png"
-             )
-             open(
-                 f"{args.save_dir}/{args.task}/{settings}/{index}/prompt.txt", "w"
-             ).write(
-                 f"{prompt} \n Initial Rewards: {init_rewards} \n Best Rewards: {best_rewards}"
-             )
-             logging.info(f"Initial rewards: {init_rewards}")
-             logging.info(f"Best rewards: {best_rewards}")
-             initial_reward = init_rewards[args.benchmark_reward]
-             best_reward = best_rewards[args.benchmark_reward]
-             total_reward_diff += best_reward - initial_reward
-             total_best_reward += best_reward
-             total_init_reward += initial_reward
-             if best_reward < initial_reward:
-                 total_improved_samples += 1
-             if i == 0:
-                 total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
-                 total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
-             for k in best_rewards.keys():
-                 total_best_rewards[k] += best_rewards[k]
-                 total_init_rewards[k] += init_rewards[k]
-             # Get new latents and optimizer
-             init_latents = torch.randn(shape, device=device, dtype=dtype)
-             latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
-             optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
-         improvement_percentage = total_improved_samples / parti_dataset.num_rows
-         mean_best_reward = total_best_reward / parti_dataset.num_rows
-         mean_init_reward = total_init_reward / parti_dataset.num_rows
-         mean_reward_diff = total_reward_diff / parti_dataset.num_rows
-         logging.info(
-             f"Improvement percentage: {improvement_percentage:.4f}, "
-             f"mean initial reward: {mean_init_reward:.4f}, "
-             f"mean best reward: {mean_best_reward:.4f}, "
-             f"mean reward diff: {mean_reward_diff:.4f}"
-         )
-         for k in total_best_rewards.keys():
-             total_best_rewards[k] /= len(parti_dataset)
-             total_init_rewards[k] /= len(parti_dataset)
-         # save results
-         os.makedirs(f"{args.save_dir}/parti-prompts/{settings}", exist_ok=True)
-         with open(f"{args.save_dir}/parti-prompts/{settings}/results.txt", "w") as f:
-             f.write(
-                 f"Mean improvement: {improvement_percentage:.4f}, "
-                 f"mean initial reward: {mean_init_reward:.4f}, "
-                 f"mean best reward: {mean_best_reward:.4f}, "
-                 f"mean reward diff: {mean_reward_diff:.4f}\n"
-                 f"Mean initial all rewards: {total_init_rewards}\n"
-                 f"Mean best all rewards: {total_best_rewards}"
-             )
-     elif args.task == "geneval":
-         prompt_list_file = "../geneval/prompts/evaluation_metadata.jsonl"
-         with open(prompt_list_file) as fp:
-             metadatas = [json.loads(line) for line in fp]
-         outdir = f"{args.save_dir}/{args.task}/{settings}"
-         for index, metadata in enumerate(metadatas):
-             # Get new latents and optimizer
-             init_latents = torch.randn(shape, device=device, dtype=dtype)
-             latents = torch.nn.Parameter(init_latents, requires_grad=True)
-             optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
-
-             prompt = metadata["prompt"]
-             init_image, best_image, init_rewards, best_rewards = trainer.train(
-                 latents, prompt, optimizer, None, multi_apply_fn
-             )
-             logging.info(f"Initial rewards: {init_rewards}")
-             logging.info(f"Best rewards: {best_rewards}")
-             outpath = f"{outdir}/{index:0>5}"
-             os.makedirs(f"{outpath}/samples", exist_ok=True)
-             with open(f"{outpath}/metadata.jsonl", "w") as fp:
-                 json.dump(metadata, fp)
-             best_image.save(f"{outpath}/samples/{args.seed:05}.png")
-             if i == 0:
-                 total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
-                 total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
-             for k in best_rewards.keys():
-                 total_best_rewards[k] += best_rewards[k]
-                 total_init_rewards[k] += init_rewards[k]
-         for k in total_best_rewards.keys():
-             total_best_rewards[k] /= len(parti_dataset)
-             total_init_rewards[k] /= len(parti_dataset)
      else:
-         raise ValueError(f"Unknown task {args.task}")
-     # log total rewards
-     logging.info(f"Mean initial rewards: {total_init_rewards}")
-     logging.info(f"Mean best rewards: {total_best_rewards}")
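
Two of the deleted branches above test a loop variable they never define: the parti-prompts and geneval loops iterate with index but check if i == 0, and the geneval branch normalizes by len(parti_dataset), which only exists in the parti-prompts branch. For reference, a corrected accumulator for loops of this shape might look like the following (an illustrative sketch, not part of the commit):

    total_best_rewards = {}
    for index, item in enumerate(items):     # items stands in for the task's prompt list
        best_rewards = run_one(item)         # hypothetical per-item call
        if index == 0:                       # key on the loop variable actually in scope
            total_best_rewards = {k: 0.0 for k in best_rewards}
        for k in best_rewards:
            total_best_rewards[k] += best_rewards[k]
    for k in total_best_rewards:
        total_best_rewards[k] /= len(items)  # normalize by this task's own count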
-
- def main():
-     args = parse_args()
-     args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup=None)
-     execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe)
-
- if __name__ == "__main__":
-     main()

  import torch
  import gc
+ import gradio as gr
+ from main import setup, execute_task
  from arguments import parse_args
+ import os
+ import shutil
+ import glob
+ import time
+ import threading
+ import argparse
+
+ def list_iter_images(save_dir):
+     # Specify only PNG images
+     image_extension = 'png'
+
+     # Create a list to store the image file paths
+     image_paths = []
+
+     # Use glob to find all PNG image files
+     all_images = glob.glob(os.path.join(save_dir, f'*.{image_extension}'))
+
+     # Filter out 'best_image.png'
+     image_paths = [img for img in all_images if os.path.basename(img) != 'best_image.png']
+
+     return image_paths
+
+ def clean_dir(save_dir):
+     # Check if the directory exists
+     if os.path.exists(save_dir):
+         # Check if the directory contains any files
+         if len(os.listdir(save_dir)) > 0:
+             # If it contains files, delete all files in the directory
+             for filename in os.listdir(save_dir):
+                 file_path = os.path.join(save_dir, filename)
+                 try:
+                     if os.path.isfile(file_path) or os.path.islink(file_path):
+                         os.unlink(file_path)  # Remove file or symbolic link
+                     elif os.path.isdir(file_path):
+                         shutil.rmtree(file_path)  # Remove directory and its contents
+                 except Exception as e:
+                     print(f"Failed to delete {file_path}. Reason: {e}")
+             print(f"All files in {save_dir} have been deleted.")
+         else:
+             print(f"{save_dir} exists but is empty.")
+     else:
+         print(f"{save_dir} does not exist.")
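
Note that glob.glob() returns paths in arbitrary order, so a gallery built from list_iter_images() may appear shuffled. A minimal sketch of a numeric sort, assuming the per-iteration files are named 0.png, 1.png, ... as generate_image() below expects:

    iter_images = sorted(
        list_iter_images(save_dir),
        key=lambda p: int(os.path.splitext(os.path.basename(p))[0]),  # assumes purely numeric stems
    )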
 
+ def start_over(gallery_state):
+     torch.cuda.empty_cache()  # Free up cached memory
+     gc.collect()
+     if gallery_state is not None:
+         gallery_state = None
+     return gallery_state, None, None, gr.update(visible=False)
+
+ def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
+     gr.Info(f"Loading {model} model...")
+
+     if prompt is None or prompt == "":
+         raise gr.Error("You forgot to provide a prompt!")
+
+     print(f"LOADED_MODEL SETUP: {loaded_model_setup}")
+
+     # Clear CUDA memory before starting the training.
+     torch.cuda.empty_cache()  # Free up cached memory
+     gc.collect()
+
+     # Set up arguments
+     args = parse_args()
+     args.task = "single"
+     args.prompt = prompt
+     args.model = model
+     args.seed = seed
+     args.n_iters = num_iterations
+     args.lr = learning_rate
+     args.cache_dir = "./HF_model_cache"
+     args.save_dir = "./outputs"
+     args.save_all_images = True
+
+     if enable_hps is True:
+         args.disable_hps = False
+         args.hps_weighting = hps_w
+
+     if enable_imagereward is True:
+         args.disable_imagereward = False
+         args.imagereward_weighting = imgrw_w
+
+     if enable_pickscore is True:
+         args.disable_pickscore = False
+         args.pickscore_weighting = pcks_w
+
+     if enable_clip is True:
+         args.disable_clip = False
+         args.clip_weighting = clip_w
+
+     if model == "flux":
+         args.cpu_offloading = True
+         args.enable_multi_apply = True
+         args.multi_step_model = "flux"
+
+     # Check if args are the same as the loaded_model_setup, except for the prompt
+     if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
+         previous_args = loaded_model_setup[0]
+
+         # Exclude 'prompt' from the comparison
+         new_args_dict = {k: v for k, v in args.__dict__.items() if k != 'prompt'}
+         prev_args_dict = {k: v for k, v in previous_args.__dict__.items() if k != 'prompt'}
+
+         if new_args_dict == prev_args_dict:
+             # If the arguments (excluding prompt) are the same, reuse the loaded setup
+             print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
+
+             # Update the prompt in the loaded_model_setup
+             loaded_model_setup[0].prompt = prompt
+
+             yield f"{model} model already loaded with the same configuration.", loaded_model_setup
+
+     # Attempt to set up the model
+     try:
+         # If other args differ, proceed with the setup
+         args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup)
+         new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings, pipe]
+         yield f"{model} model loaded successfully!", new_loaded_setup
+
+     except Exception as e:
+         print(f"Failed to load {model} model: {e}.")
+         yield f"Failed to load {model} model: {e}. You can try again; it usually loads on the second try :)", None
+
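
setup_model() is a generator, so a caller has to drain it and keep the last yield; the pattern combined_function() below relies on, in miniature (a self-contained illustration, not part of the commit):

    def load_model_steps():
        yield "loading...", None
        yield "loaded!", ["args", "trainer", "..."]  # stand-in for the 8-item setup list

    status, state = None, None
    for status, state in load_model_steps():
        print(status)  # "loading..." then "loaded!"
    # state now holds the final setup list, exactly how combined_function keeps the last yield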
+ def generate_image(setup_args, num_iterations):
      torch.cuda.empty_cache()  # Free up cached memory
      gc.collect()

+     gr.Info("Executing iterations task...")
+
+     args = setup_args[0]
+     trainer = setup_args[1]
+     device = setup_args[2]
+     dtype = setup_args[3]
+     shape = setup_args[4]
+     enable_grad = setup_args[5]
+     settings = setup_args[6]
+     print(f"SETTINGS: {settings}")
+     pipe = setup_args[7]
+
+     save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
+     clean_dir(save_dir)
+
+     try:
+         torch.cuda.empty_cache()  # Free up cached memory
+         gc.collect()
+         steps_completed = []
+         result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
+         error_status = {"error_occurred": False}  # Shared dictionary to track error status
+         thread_status = {"running": False}  # Track whether a thread is already running
+
+         def progress_callback(step):
+             # Limit redundant prints by checking the step number
+             if not steps_completed or step > steps_completed[-1]:
+                 steps_completed.append(step)
+                 print(f"Progress: Step {step} completed.")
+
+         def run_main():
+             thread_status["running"] = True  # Mark thread as running
+             try:
+                 execute_task(
+                     args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback
+                 )
+             except torch.cuda.OutOfMemoryError as e:
+                 print(f"CUDA Out of Memory Error: {e}")
+                 error_status["error_occurred"] = True
+             except RuntimeError as e:
+                 if 'out of memory' in str(e):
+                     print(f"Runtime Error: {e}")
+                     error_status["error_occurred"] = True
+                 else:
+                     raise
+             finally:
+                 thread_status["running"] = False  # Mark thread as completed
+
+         if not thread_status["running"]:  # Ensure no other thread is running
+             main_thread = threading.Thread(target=run_main)
+             main_thread.start()
+
+             last_step_yielded = 0
+             while main_thread.is_alive() and not error_status["error_occurred"]:
+                 # Check if new steps have been completed
+                 if steps_completed and steps_completed[-1] > last_step_yielded:
+                     last_step_yielded = steps_completed[-1]
+                     png_number = last_step_yielded - 1
+                     # Get the image for this step
+                     image_path = os.path.join(save_dir, f"{png_number}.png")
+                     if os.path.exists(image_path):
+                         yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
+                     else:
+                         yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
+                 else:
+                     time.sleep(0.1)  # Sleep to prevent busy waiting
+
+             if error_status["error_occurred"]:
+                 torch.cuda.empty_cache()  # Free up cached memory
+                 gc.collect()
+                 yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
+             else:
+                 main_thread.join()  # Ensure thread completion
+                 final_image_path = os.path.join(save_dir, "best_image.png")
+                 if os.path.exists(final_image_path):
+                     iter_images = list_iter_images(save_dir)
+                     torch.cuda.empty_cache()  # Free up cached memory
+                     gc.collect()
+                     time.sleep(0.5)
+                     yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
+                 else:
+                     torch.cuda.empty_cache()  # Free up cached memory
+                     gc.collect()
+                     yield (None, "Image generation completed, but no final image was found.", None)
+
          torch.cuda.empty_cache()  # Free up cached memory
          gc.collect()

+     except torch.cuda.OutOfMemoryError as e:
+         print(f"Global CUDA Out of Memory Error: {e}")
+         yield (None, f"{e}", None)
+     except RuntimeError as e:
+         if 'out of memory' in str(e):
+             print(f"Runtime Error: {e}")
+             yield (None, f"{e}", None)
+         else:
+             yield (None, f"An error occurred: {str(e)}", None)
+     except Exception as e:
+         print(f"Unexpected Error: {e}")
+         yield (None, f"An unexpected error occurred: {str(e)}", None)
+
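
The worker-thread-plus-polling shape above is what lets a Gradio generator stream intermediate images while execute_task() blocks; the pattern in miniature (a self-contained illustration, not part of the commit):

    import threading, time

    done = []

    def work():                        # stands in for execute_task()
        for i in range(3):
            time.sleep(0.2)
            done.append(i + 1)         # stands in for progress_callback(step)

    def stream_progress():
        worker = threading.Thread(target=work)
        worker.start()
        seen = 0
        while worker.is_alive() or seen < len(done):
            if len(done) > seen:
                seen = len(done)
                yield f"step {seen}"   # a UI framework would render each yield as it arrives
            else:
                time.sleep(0.05)       # poll instead of busy-waiting, as generate_image does
        worker.join()

    for msg in stream_progress():
        print(msg)                     # step 1, step 2, step 3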
+ def show_gallery_output(gallery_state):
+     if gallery_state is not None:
+         return gr.update(value=gallery_state, visible=True)
      else:
+         return gr.update(value=None, visible=False)
+
+ def combined_function(gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
+     # Step 1: Start over
+     gallery_state, output_image, status, iter_gallery_update = start_over(gallery_state)
+     model_status = ""  # No model status yet
+     yield gallery_state, output_image, status, iter_gallery_update, loaded_model_setup, model_status
+
+     # Step 2: Set up the model
+     model_status, new_loaded_model_setup = None, None
+     for model_status, new_loaded_model_setup in setup_model(
+             loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w,
+             enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate):
+         yield gallery_state, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status
+
+     # Step 3: Generate the image
+     output_image, status, gallery_state_update = None, None, None
+     for output_image, status, gallery_state_update in generate_image(new_loaded_model_setup, n_iter):
+         yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status
+
+     # Step 4: Show the gallery
+     iter_gallery_update = show_gallery_output(gallery_state_update)
+     yield gallery_state_update, output_image, status, iter_gallery_update, new_loaded_model_setup, model_status
+
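
Each yield from combined_function() repaints every component listed in its outputs below; the streaming contract in miniature (assuming a recent Gradio, where generators stream only through the queue):

    import gradio as gr

    def stream():
        for i in range(3):
            yield f"step {i + 1}"  # each yield live-updates the bound Textbox

    with gr.Blocks() as mini:
        box = gr.Textbox()
        gr.Button("go").click(fn=stream, outputs=[box])

    # mini.queue().launch()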
+ # Create Gradio interface
+ title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
+ description = "Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
+
+ css = """
+ #model-status-id{
+     height: 126px;
+ }
+ #model-status-id .progress-text{
+     font-size: 10px!important;
+ }
+ #model-status-id .progress-level-inner{
+     font-size: 8px!important;
+ }
+ """
+
+ with gr.Blocks(css=css, analytics_enabled=False) as demo:
+     loaded_model_setup = gr.State()
+     gallery_state = gr.State()
+     with gr.Column():
+         gr.Markdown(title)
+         gr.Markdown(description)
+         gr.HTML("""
+             <div style="display:flex;column-gap:4px;">
+                 <a href='https://github.com/ExplainableML/ReNO'>
+                     <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+                 </a>
+                 <a href='https://arxiv.org/abs/2406.04312v1'>
+                     <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
+                 </a>
+             </div>
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 prompt = gr.Textbox(label="Prompt")
+                 with gr.Row():
+                     chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
+                     seed = gr.Number(label="seed", value=0)
+
+                 model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
+
+                 with gr.Row():
+                     n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=10, label="Number of Iterations")
+                     learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
+
+                 with gr.Accordion("Advanced Settings", open=True):
+                     with gr.Column():
+                         with gr.Row():
+                             enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
+                             hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
+                         with gr.Row():
+                             enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
+                             imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
+                         with gr.Row():
+                             enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
+                             pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
+                         with gr.Row():
+                             enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
+                             clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
+
+                 submit_btn = gr.Button("Submit")
+
+                 gr.Examples(
+                     examples=[
+                         "A red dog and a green cat",
+                         "A pink elephant and a grey cow",
+                         "A toaster riding a bike",
+                         "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
+                         "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
+                         "An epic oil painting: a red portal in front of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
+                     ],
+                     inputs=[prompt]
+                 )
+
+             with gr.Column():
+                 output_image = gr.Image(type="filepath", label="Best Generated Image")
+                 status = gr.Textbox(label="Status")
+                 iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
+
+     def allow_weighting(weight_type):
+         if weight_type is True:
+             return gr.update(interactive=True)
+         else:
+             return gr.update(interactive=False)
+
+     enable_hps.change(
+         fn=allow_weighting,
+         inputs=[enable_hps],
+         outputs=[hps_w],
+         queue=False
+     )
+     enable_imagereward.change(
+         fn=allow_weighting,
+         inputs=[enable_imagereward],
+         outputs=[imgrw_w],
+         queue=False
+     )
+     enable_pickscore.change(
+         fn=allow_weighting,
+         inputs=[enable_pickscore],
+         outputs=[pcks_w],
+         queue=False
+     )
+     enable_clip.change(
+         fn=allow_weighting,
+         inputs=[enable_clip],
+         outputs=[clip_w],
+         queue=False
+     )
+
+     submit_btn.click(
+         fn=combined_function,
+         inputs=[
+             gallery_state, loaded_model_setup, prompt, chosen_model, seed, n_iter,
+             enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore,
+             pcks_w, enable_clip, clip_w, learning_rate
+         ],
+         outputs=[
+             gallery_state, output_image, status, iter_gallery, loaded_model_setup, model_status  # Ensure `model_status` is included in the outputs
+         ]
+     )
+
+     """
+     submit_btn.click(
+         fn=start_over,
+         inputs=[gallery_state],
+         outputs=[gallery_state, output_image, status, iter_gallery]
+     ).then(
+         fn=setup_model,
+         inputs=[loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
+         outputs=[model_status, loaded_model_setup]  # Load the new setup into the state
+     ).then(
+         fn=generate_image,
+         inputs=[loaded_model_setup, n_iter],
+         outputs=[output_image, status, gallery_state]
+     ).then(
+         fn=show_gallery_output,
+         inputs=[gallery_state],
+         outputs=iter_gallery
+     )
+     """
+
+ # Launch the app
+ demo.queue().launch(show_error=True, show_api=False)
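
To try the new app locally, something along these lines should work (assumed invocation; the Space's requirements.txt is the authoritative dependency list):

    # clone the Space so main.py, arguments.py and the models/rewards/training modules sit next to app.py
    # pip install -r requirements.txt
    # python app.py  # Gradio serves on http://127.0.0.1:7860 by default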