fffiloni committed on
Commit
dd551fd
1 Parent(s): 48a11d1

model setup optimizations

Browse files
Files changed (1) hide show
  1. main.py +61 -6
main.py CHANGED
@@ -15,12 +15,11 @@ from rewards import get_reward_losses
15
  from training import LatentNoiseTrainer, get_optimizer
16
 
17
 
18
- def setup(args):
19
-
20
  seed_everything(args.seed)
21
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
 
22
  # Set up logging and name settings
23
- # Get the root logger and clear existing handlers
24
  logger = logging.getLogger()
25
  logger.handlers.clear() # Clear existing handlers
26
  settings = (
@@ -34,6 +33,7 @@ def setup(args):
34
  f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
35
  f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
36
  )
 
37
  file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
38
  handler = logging.StreamHandler(file_stream)
39
  formatter = logging.Formatter("%(asctime)s - %(message)s")
@@ -43,16 +43,68 @@ def setup(args):
43
  consoleHandler = logging.StreamHandler()
44
  consoleHandler.setFormatter(formatter)
45
  logger.addHandler(consoleHandler)
 
46
  logging.info(args)
 
47
  if args.device_id is not None:
48
  logging.info(f"Using CUDA device {args.device_id}")
49
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
50
  os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
 
51
  device = torch.device("cuda")
52
  if args.dtype == "float32":
53
  dtype = torch.float32
54
  elif args.dtype == "float16":
55
  dtype = torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Get reward losses
57
  reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
58
 
@@ -63,7 +115,7 @@ def setup(args):
63
 
64
  torch.cuda.empty_cache() # Free up cached memory
65
  gc.collect()
66
-
67
  trainer = LatentNoiseTrainer(
68
  reward_losses=reward_losses,
69
  model=pipe,
@@ -85,7 +137,6 @@ def setup(args):
85
 
86
  # Create latents
87
  if args.model == "flux":
88
- # currently only support 512x512 generation
89
  shape = (1, 16 * 64, 64)
90
  elif args.model != "pixart":
91
  height = pipe.unet.config.sample_size * pipe.vae_scale_factor
@@ -107,6 +158,9 @@ def setup(args):
107
  )
108
 
109
  enable_grad = not args.no_optim
 
 
 
110
 
111
  if args.enable_multi_apply:
112
  multi_apply_fn = get_multi_apply_fn(
@@ -121,6 +175,7 @@ def setup(args):
121
  multi_apply_fn = None
122
 
123
  torch.cuda.empty_cache() # Free up cached memory
 
124
 
125
  return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
126
 
@@ -308,7 +363,7 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_f
308
 
309
  def main():
310
  args = parse_args()
311
- args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
312
  execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)
313
 
314
  if __name__ == "__main__":
 
15
  from training import LatentNoiseTrainer, get_optimizer
16
 
17
 
18
+ def setup(args, loaded_model_setup=None):
 
19
  seed_everything(args.seed)
20
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
21
+
22
  # Set up logging and name settings
 
23
  logger = logging.getLogger()
24
  logger.handlers.clear() # Clear existing handlers
25
  settings = (
 
33
  f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
34
  f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
35
  )
36
+
37
  file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
38
  handler = logging.StreamHandler(file_stream)
39
  formatter = logging.Formatter("%(asctime)s - %(message)s")
 
43
  consoleHandler = logging.StreamHandler()
44
  consoleHandler.setFormatter(formatter)
45
  logger.addHandler(consoleHandler)
46
+
47
  logging.info(args)
48
+
49
  if args.device_id is not None:
50
  logging.info(f"Using CUDA device {args.device_id}")
51
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
52
  os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
53
+
54
  device = torch.device("cuda")
55
  if args.dtype == "float32":
56
  dtype = torch.float32
57
  elif args.dtype == "float16":
58
  dtype = torch.float16
59
+
60
+ # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
61
+ if loaded_model_setup and args.model == loaded_model_setup[0].model:
62
+ # Reuse the trainer and pipe from the loaded model setup
63
+ print(f"Reusing model {args.model} from loaded setup.")
64
+ trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
65
+
66
+ # Update trainer with the new arguments
67
+ trainer.n_iters = args.n_iters
68
+ trainer.n_inference_steps = args.n_inference_steps
69
+ trainer.seed = args.seed
70
+ trainer.save_all_images = args.save_all_images
71
+ trainer.no_optim = args.no_optim
72
+ trainer.regularize = args.enable_reg
73
+ trainer.regularization_weight = args.reg_weight
74
+ trainer.grad_clip = args.grad_clip
75
+ trainer.log_metrics = args.task == "single" or not args.no_optim
76
+ trainer.imageselect = args.imageselect
77
+
78
+ # Get latents (this step is still required)
79
+ if args.model == "flux":
80
+ shape = (1, 16 * 64, 64)
81
+ elif args.model != "pixart":
82
+ height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
83
+ width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
84
+ shape = (
85
+ 1,
86
+ trainer.model.unet.in_channels,
87
+ height // trainer.model.vae_scale_factor,
88
+ width // trainer.model.vae_scale_factor,
89
+ )
90
+ else:
91
+ height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
92
+ width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
93
+ shape = (
94
+ 1,
95
+ trainer.model.transformer.config.in_channels,
96
+ height // trainer.model.vae_scale_factor,
97
+ width // trainer.model.vae_scale_factor,
98
+ )
99
+
100
+ multi_apply_fn = loaded_model_setup[6]
101
+ enable_grad = not args.no_optim
102
+
103
+ return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
104
+
105
+ # Proceed with full model loading if args.model is different
106
+ print(f"Loading new model: {args.model}")
107
+
108
  # Get reward losses
109
  reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
110
 
 
115
 
116
  torch.cuda.empty_cache() # Free up cached memory
117
  gc.collect()
118
+
119
  trainer = LatentNoiseTrainer(
120
  reward_losses=reward_losses,
121
  model=pipe,
 
137
 
138
  # Create latents
139
  if args.model == "flux":
 
140
  shape = (1, 16 * 64, 64)
141
  elif args.model != "pixart":
142
  height = pipe.unet.config.sample_size * pipe.vae_scale_factor
 
158
  )
159
 
160
  enable_grad = not args.no_optim
161
+
162
+ torch.cuda.empty_cache() # Free up cached memory
163
+ gc.collect()
164
 
165
  if args.enable_multi_apply:
166
  multi_apply_fn = get_multi_apply_fn(
 
175
  multi_apply_fn = None
176
 
177
  torch.cuda.empty_cache() # Free up cached memory
178
+ gc.collect()
179
 
180
  return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
181
 
 
363
 
364
  def main():
365
  args = parse_args()
366
+ args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup=None)
367
  execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)
368
 
369
  if __name__ == "__main__":