Spaces:
Sleeping
Sleeping
model setup optimizations
Browse files
main.py
CHANGED
@@ -15,12 +15,11 @@ from rewards import get_reward_losses
|
|
15 |
from training import LatentNoiseTrainer, get_optimizer
|
16 |
|
17 |
|
18 |
-
def setup(args):
|
19 |
-
|
20 |
seed_everything(args.seed)
|
21 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
|
|
22 |
# Set up logging and name settings
|
23 |
-
# Get the root logger and clear existing handlers
|
24 |
logger = logging.getLogger()
|
25 |
logger.handlers.clear() # Clear existing handlers
|
26 |
settings = (
|
@@ -34,6 +33,7 @@ def setup(args):
|
|
34 |
f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
|
35 |
f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
|
36 |
)
|
|
|
37 |
file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
|
38 |
handler = logging.StreamHandler(file_stream)
|
39 |
formatter = logging.Formatter("%(asctime)s - %(message)s")
|
@@ -43,16 +43,68 @@ def setup(args):
|
|
43 |
consoleHandler = logging.StreamHandler()
|
44 |
consoleHandler.setFormatter(formatter)
|
45 |
logger.addHandler(consoleHandler)
|
|
|
46 |
logging.info(args)
|
|
|
47 |
if args.device_id is not None:
|
48 |
logging.info(f"Using CUDA device {args.device_id}")
|
49 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
50 |
os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
|
|
|
51 |
device = torch.device("cuda")
|
52 |
if args.dtype == "float32":
|
53 |
dtype = torch.float32
|
54 |
elif args.dtype == "float16":
|
55 |
dtype = torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# Get reward losses
|
57 |
reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
|
58 |
|
@@ -63,7 +115,7 @@ def setup(args):
|
|
63 |
|
64 |
torch.cuda.empty_cache() # Free up cached memory
|
65 |
gc.collect()
|
66 |
-
|
67 |
trainer = LatentNoiseTrainer(
|
68 |
reward_losses=reward_losses,
|
69 |
model=pipe,
|
@@ -85,7 +137,6 @@ def setup(args):
|
|
85 |
|
86 |
# Create latents
|
87 |
if args.model == "flux":
|
88 |
-
# currently only support 512x512 generation
|
89 |
shape = (1, 16 * 64, 64)
|
90 |
elif args.model != "pixart":
|
91 |
height = pipe.unet.config.sample_size * pipe.vae_scale_factor
|
@@ -107,6 +158,9 @@ def setup(args):
|
|
107 |
)
|
108 |
|
109 |
enable_grad = not args.no_optim
|
|
|
|
|
|
|
110 |
|
111 |
if args.enable_multi_apply:
|
112 |
multi_apply_fn = get_multi_apply_fn(
|
@@ -121,6 +175,7 @@ def setup(args):
|
|
121 |
multi_apply_fn = None
|
122 |
|
123 |
torch.cuda.empty_cache() # Free up cached memory
|
|
|
124 |
|
125 |
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
|
126 |
|
@@ -308,7 +363,7 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_f
|
|
308 |
|
309 |
def main():
|
310 |
args = parse_args()
|
311 |
-
args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
|
312 |
execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)
|
313 |
|
314 |
if __name__ == "__main__":
|
|
|
15 |
from training import LatentNoiseTrainer, get_optimizer
|
16 |
|
17 |
|
18 |
+
def setup(args, loaded_model_setup=None):
|
|
|
19 |
seed_everything(args.seed)
|
20 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
21 |
+
|
22 |
# Set up logging and name settings
|
|
|
23 |
logger = logging.getLogger()
|
24 |
logger.handlers.clear() # Clear existing handlers
|
25 |
settings = (
|
|
|
33 |
f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
|
34 |
f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
|
35 |
)
|
36 |
+
|
37 |
file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
|
38 |
handler = logging.StreamHandler(file_stream)
|
39 |
formatter = logging.Formatter("%(asctime)s - %(message)s")
|
|
|
43 |
consoleHandler = logging.StreamHandler()
|
44 |
consoleHandler.setFormatter(formatter)
|
45 |
logger.addHandler(consoleHandler)
|
46 |
+
|
47 |
logging.info(args)
|
48 |
+
|
49 |
if args.device_id is not None:
|
50 |
logging.info(f"Using CUDA device {args.device_id}")
|
51 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
52 |
os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
|
53 |
+
|
54 |
device = torch.device("cuda")
|
55 |
if args.dtype == "float32":
|
56 |
dtype = torch.float32
|
57 |
elif args.dtype == "float16":
|
58 |
dtype = torch.float16
|
59 |
+
|
60 |
+
# If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
|
61 |
+
if loaded_model_setup and args.model == loaded_model_setup[0].model:
|
62 |
+
# Reuse the trainer and pipe from the loaded model setup
|
63 |
+
print(f"Reusing model {args.model} from loaded setup.")
|
64 |
+
trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
|
65 |
+
|
66 |
+
# Update trainer with the new arguments
|
67 |
+
trainer.n_iters = args.n_iters
|
68 |
+
trainer.n_inference_steps = args.n_inference_steps
|
69 |
+
trainer.seed = args.seed
|
70 |
+
trainer.save_all_images = args.save_all_images
|
71 |
+
trainer.no_optim = args.no_optim
|
72 |
+
trainer.regularize = args.enable_reg
|
73 |
+
trainer.regularization_weight = args.reg_weight
|
74 |
+
trainer.grad_clip = args.grad_clip
|
75 |
+
trainer.log_metrics = args.task == "single" or not args.no_optim
|
76 |
+
trainer.imageselect = args.imageselect
|
77 |
+
|
78 |
+
# Get latents (this step is still required)
|
79 |
+
if args.model == "flux":
|
80 |
+
shape = (1, 16 * 64, 64)
|
81 |
+
elif args.model != "pixart":
|
82 |
+
height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
|
83 |
+
width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
|
84 |
+
shape = (
|
85 |
+
1,
|
86 |
+
trainer.model.unet.in_channels,
|
87 |
+
height // trainer.model.vae_scale_factor,
|
88 |
+
width // trainer.model.vae_scale_factor,
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
|
92 |
+
width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
|
93 |
+
shape = (
|
94 |
+
1,
|
95 |
+
trainer.model.transformer.config.in_channels,
|
96 |
+
height // trainer.model.vae_scale_factor,
|
97 |
+
width // trainer.model.vae_scale_factor,
|
98 |
+
)
|
99 |
+
|
100 |
+
multi_apply_fn = loaded_model_setup[6]
|
101 |
+
enable_grad = not args.no_optim
|
102 |
+
|
103 |
+
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
|
104 |
+
|
105 |
+
# Proceed with full model loading if args.model is different
|
106 |
+
print(f"Loading new model: {args.model}")
|
107 |
+
|
108 |
# Get reward losses
|
109 |
reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
|
110 |
|
|
|
115 |
|
116 |
torch.cuda.empty_cache() # Free up cached memory
|
117 |
gc.collect()
|
118 |
+
|
119 |
trainer = LatentNoiseTrainer(
|
120 |
reward_losses=reward_losses,
|
121 |
model=pipe,
|
|
|
137 |
|
138 |
# Create latents
|
139 |
if args.model == "flux":
|
|
|
140 |
shape = (1, 16 * 64, 64)
|
141 |
elif args.model != "pixart":
|
142 |
height = pipe.unet.config.sample_size * pipe.vae_scale_factor
|
|
|
158 |
)
|
159 |
|
160 |
enable_grad = not args.no_optim
|
161 |
+
|
162 |
+
torch.cuda.empty_cache() # Free up cached memory
|
163 |
+
gc.collect()
|
164 |
|
165 |
if args.enable_multi_apply:
|
166 |
multi_apply_fn = get_multi_apply_fn(
|
|
|
175 |
multi_apply_fn = None
|
176 |
|
177 |
torch.cuda.empty_cache() # Free up cached memory
|
178 |
+
gc.collect()
|
179 |
|
180 |
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
|
181 |
|
|
|
363 |
|
364 |
def main():
|
365 |
args = parse_args()
|
366 |
+
args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup=None)
|
367 |
execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)
|
368 |
|
369 |
if __name__ == "__main__":
|