better GPU memory management
main.py CHANGED
@@ -15,35 +15,83 @@ from rewards import get_reward_losses
 from training import LatentNoiseTrainer, get_optimizer
 
 
-
-
+def find_and_move_object_to_cpu():
+    for obj in gc.get_objects():
+        try:
+            # Check if the object is a PyTorch model
+            if isinstance(obj, torch.nn.Module):
+                # Check if any parameter of the model is on CUDA
+                if any(param.is_cuda for param in obj.parameters()):
+                    print(f"Found PyTorch model on CUDA: {type(obj).__name__}")
+                    # Move the model to CPU
+                    obj.to('cpu')
+                    print(f"Moved {type(obj).__name__} to CPU.")
+
+                # Optionally check if buffers are on CUDA
+                if any(buf.is_cuda for buf in obj.buffers()):
+                    print(f"Found buffer on CUDA in {type(obj).__name__}")
+                    obj.to('cpu')
+                    print(f"Moved buffers of {type(obj).__name__} to CPU.")
+
+        except Exception as e:
+            # Handle any exceptions if obj is not a torch model
+            pass
+
 
 def clear_gpu():
     """Clear GPU memory by removing tensors, freeing cache, and moving data to CPU."""
     # List memory usage before clearing
     print(f"Memory allocated before clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
     print(f"Memory reserved before clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
-
-    # Force the garbage collector to free unreferenced objects
-    gc.collect()
 
     # Move any bound tensors back to CPU if needed
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()  # Ensure that all operations are completed
+        print("GPU memory cleared.")
 
     print(f"Memory allocated after clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
     print(f"Memory reserved after clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
 
 def unload_previous_model_if_needed(loaded_model_setup):
+    # Check if any GPU memory is being used even when loaded_model_setup is None
+    if loaded_model_setup is None:
+        if torch.cuda.is_available() and torch.cuda.memory_allocated() > 0:
+            print("Unknown model or tensors are still loaded on the GPU. Clearing GPU memory.")
+            # Call the function to find and move object to CPU
+            find_and_move_object_to_cpu()
+
+        return
+
     """Unload the current model from the GPU and free resources if a new model is being loaded."""
-
-
-
-
+
+    print("Unloading previous model from GPU to free memory.")
+
+    """
+    previous_model = loaded_model_setup[7]  # Assuming pipe is at position [7] in the setup
+    # If the model is 'hyper-sd', ensure its components are moved to CPU before deletion
+    if loaded_model_setup[0].model == "hyper-sd":
+        if previous_model.device == torch.device('cuda'):
+            if hasattr(previous_model, 'unet'):
+                print("Moving UNet back to CPU.")
+                previous_model.unet.to('cpu')  # Move unet to CPU
+
+            print("Moving entire pipeline back to CPU.")
+            previous_model.to('cpu')  # Move the entire pipeline (pipe) to CPU
+    # For other models, use a generic 'to' function if available
+    elif hasattr(previous_model, 'to') and loaded_model_setup[0].model != "flux":
+        if previous_model.device == torch.device('cuda'):
+            print("Moving previous model back to CPU.")
             previous_model.to('cpu')  # Move model to CPU to free GPU memory
-
-
+
+    # Delete the reference to the model to allow garbage collection
+    del previous_model
+    """
+    # Call the function to find and move object to CPU
+    find_and_move_object_to_cpu()
+
+    # Clear GPU memory
+    clear_gpu()  # Ensure that this function properly clears memory (e.g., torch.cuda.empty_cache())
 
 def setup(args, loaded_model_setup=None):
     seed_everything(args.seed)
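Side note on the new helpers in this hunk: find_and_move_object_to_cpu() sweeps gc.get_objects() for live torch.nn.Module instances whose parameters or buffers sit on CUDA and pushes them to the CPU, and clear_gpu() then empties the caching allocator and synchronizes. A minimal, self-contained sketch of the same pattern follows; the function names and the MB formatting are my own, and the sketch keeps a gc.collect() call that the commit drops from clear_gpu().

    import gc
    import torch

    def sweep_cuda_modules_to_cpu():
        """Move any live nn.Module that still has CUDA parameters or buffers to the CPU."""
        for obj in gc.get_objects():
            try:
                if isinstance(obj, torch.nn.Module) and (
                    any(p.is_cuda for p in obj.parameters())
                    or any(b.is_cuda for b in obj.buffers())
                ):
                    obj.to("cpu")
            except Exception:
                continue  # gc can surface half-constructed objects; skip them

    def release_gpu_memory():
        """Drop dead references, return cached blocks to the driver, and report usage."""
        sweep_cuda_modules_to_cpu()
        gc.collect()  # free unreferenced tensors before emptying the cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # hand cached, unused blocks back to the driver
            torch.cuda.synchronize()   # wait for in-flight kernels before measuring
            print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MB, "
                  f"reserved: {torch.cuda.memory_reserved() / 2**20:.1f} MB")

torch.cuda.empty_cache() can only release blocks that no tensor still occupies, which is why the module sweep (and garbage collection) has to happen before it.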
@@ -144,7 +192,6 @@ def setup(args, loaded_model_setup=None):
 
     # Final memory cleanup after model loading
     torch.cuda.empty_cache()
-    gc.collect()
 
     trainer = LatentNoiseTrainer(
         reward_losses=reward_losses,
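With the gc.collect() gone, the post-loading cleanup relies on torch.cuda.empty_cache() alone. For intuition about what that call can and cannot free, here is a standalone snippet (not from this repo) contrasting memory_allocated() with memory_reserved():

    import torch

    if torch.cuda.is_available():
        kept = torch.randn(1024, 1024, device="cuda")   # ~4 MB, still referenced
        tmp = torch.randn(4096, 4096, device="cuda")    # ~64 MB
        del tmp                                         # freed by Python, but the block stays cached
        print(torch.cuda.memory_allocated() // 2**20, torch.cuda.memory_reserved() // 2**20)
        torch.cuda.empty_cache()                        # returns the cached ~64 MB block to the driver
        print(torch.cuda.memory_allocated() // 2**20, torch.cuda.memory_reserved() // 2**20)
        # 'kept' is still referenced, so its ~4 MB stays allocated either way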
@@ -188,7 +235,7 @@ def setup(args, loaded_model_setup=None):
 
     # Final memory cleanup
     torch.cuda.empty_cache()  # Free up cached memory
-
+
 
 
 
@@ -200,23 +247,30 @@ def setup(args, loaded_model_setup=None):
 def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback=None):
 
     if args.task == "single":
+
+
+
         # Attempt to move the model to GPU if model is not Flux
         if args.model != "flux":
-            if args.model
+            if args.model == "hyper-sd":
                 if pipe.device != torch.device('cuda'):
-
-
+                    # Transfer UNet to GPU
+                    pipe.unet = pipe.unet.to(device, dtype)
+                    # Transfer the whole pipe to GPU, if required (optional)
+                    pipe = pipe.to(device, dtype)
+                    # upcast vae
+                    pipe.vae = pipe.vae.to(dtype=torch.float32)
+            elif args.model == "pixart":
                 if pipe.device != torch.device('cuda'):
                     pipe.to(device)
+            else:
+                if pipe.device != torch.device('cuda'):
+                    pipe.to(device, dtype)
         else:
-            print(f"PIPE:{pipe}")
-
 
         if args.cpu_offloading:
             pipe.enable_sequential_cpu_offload()
 
-        #if pipe.device != torch.device('cuda'):
-        #    pipe.to(device, dtype)
 
         if args.enable_multi_apply:
 
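The branch above now places the pipeline on the GPU per model family: "hyper-sd" moves the UNet and then the whole pipe in the working dtype but keeps the VAE in float32 (the "upcast vae" step), "pixart" is moved without a dtype cast, any other non-flux model is moved with both device and dtype, and "flux" is left where it was loaded. Condensed into one helper for clarity (the helper name is mine; it assumes a diffusers-style pipeline exposing .device, .unet and .vae):

    import torch

    def place_pipe(pipe, model_name, device, dtype):
        """Mirror the per-model GPU placement above; 'flux' pipelines are left untouched."""
        if model_name == "flux" or pipe.device == torch.device("cuda"):
            return pipe
        if model_name == "hyper-sd":
            pipe.unet = pipe.unet.to(device, dtype)      # UNet in the working dtype
            pipe = pipe.to(device, dtype)                # remaining components follow
            pipe.vae = pipe.vae.to(dtype=torch.float32)  # keep the VAE in full precision
        elif model_name == "pixart":
            pipe = pipe.to(device)                       # keep whatever dtype it was loaded with
        else:
            pipe = pipe.to(device, dtype)
        return pipe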
@@ -232,8 +286,8 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pip
         multi_apply_fn = None
 
     torch.cuda.empty_cache()  # Free up cached memory
-
-
+
+    print(f"PIPE:{pipe}")
 
     init_latents = torch.randn(shape, device=device, dtype=dtype)
     latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
@@ -246,6 +300,28 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pip
         best_image.save(f"{save_dir}/best_image.png")
         #init_image.save(f"{save_dir}/init_image.png")
 
+        # Move the pipe back to CPU
+        if args.model != "flux":
+            if args.model == "hyper-sd":
+                if pipe.device == torch.device('cuda'):
+                    print("Moving the entire pipe back to CPU.")
+                    # Transfer UNet back to CPU
+                    pipe.unet = pipe.unet.to("cpu")
+                    pipe.to('cpu')  # Move all components of the pipe back to CPU
+                    # Delete the pipe to free resources
+                    del pipe
+                    print("Pipe deleted to free resources.")
+
+            else:
+                if pipe.device == torch.device('cuda'):
+                    print("Moving the entire pipe back to CPU.")
+                    pipe.to("cpu")
+                    # Delete the pipe to free resources
+                    del pipe
+                    print("Pipe deleted to free resources.")
+
+            clear_gpu()
+
     elif args.task == "example-prompts":
         fo = open("assets/example_prompts.txt", "r")
         prompts = fo.readlines()
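After the best image is saved, the single-prompt path now pushes the pipe back to the CPU, deletes the local reference, and calls clear_gpu(). One case this does not cover is an exception raised inside the optimisation itself; a try/finally wrapper along these lines would guarantee the same cleanup either way (a sketch under that assumption, not code from the commit):

    import torch

    def run_with_gpu_cleanup(pipe, model_name, task_fn):
        """Run task_fn(pipe), then always return the pipeline to the CPU and clear the cache."""
        try:
            return task_fn(pipe)
        finally:
            if model_name != "flux" and pipe.device == torch.device("cuda"):
                pipe.to("cpu")             # release the pipeline's GPU tensors
            if torch.cuda.is_available():
                torch.cuda.empty_cache()   # give cached blocks back to the driver
                torch.cuda.synchronize()

Also note that del pipe only removes the local name; the memory is reclaimed once no other reference (for example one kept inside loaded_model_setup) still points at the pipeline.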