aifeifei798 committed
Commit 3d07acf
1 Parent(s): 851bd03

Upload model_management.py

Files changed (1)
  1. ldm_patched/modules/model_management.py +791 -807
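The diff below replaces the file wholesale; the substantive change visible in it is that get_total_memory() now queries the CUDA allocator directly (the CPU/MPS, DirectML and Intel XPU branches and the torch_total_too tuple return are dropped, the parameter is now ignored), and the startup VRAM/RAM report together with the OOM_EXCEPTION definition is commented out. A minimal usage sketch of the updated helpers, assuming a CUDA build of torch and that the ldm_patched package is on the import path (the import path and function names come from the file below):

    import ldm_patched.modules.model_management as mm

    dev = mm.get_torch_device()
    # After this commit get_total_memory() reads torch.cuda.memory_stats()
    # unconditionally, so it is only meaningful on a CUDA device.
    total_mb = mm.get_total_memory(dev) / (1024 * 1024)
    free_mb = mm.get_free_memory(dev) / (1024 * 1024)
    print("total: {:0.0f} MB, free: {:0.0f} MB".format(total_mb, free_mb))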
ldm_patched/modules/model_management.py CHANGED
@@ -1,807 +1,791 @@
- import psutil
- from enum import Enum
- from ldm_patched.modules.args_parser import args
- import ldm_patched.modules.utils
- import torch
- import sys
-
8
- class VRAMState(Enum):
9
- DISABLED = 0 #No vram present: no need to move models to vram
10
- NO_VRAM = 1 #Very low vram: enable all the options to save vram
11
- LOW_VRAM = 2
12
- NORMAL_VRAM = 3
13
- HIGH_VRAM = 4
14
- SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
15
-
16
- class CPUState(Enum):
17
- GPU = 0
18
- CPU = 1
19
- MPS = 2
20
-
21
- # Determine VRAM State
22
- vram_state = VRAMState.NORMAL_VRAM
23
- set_vram_to = VRAMState.NORMAL_VRAM
24
- cpu_state = CPUState.GPU
25
-
26
- total_vram = 0
27
-
28
- lowvram_available = True
29
- xpu_available = False
30
-
31
- if args.pytorch_deterministic:
32
- print("Using deterministic algorithms for pytorch")
33
- torch.use_deterministic_algorithms(True, warn_only=True)
34
-
35
- directml_enabled = False
36
- if args.directml is not None:
37
- import torch_directml
38
- directml_enabled = True
39
- device_index = args.directml
40
- if device_index < 0:
41
- directml_device = torch_directml.device()
42
- else:
43
- directml_device = torch_directml.device(device_index)
44
- print("Using directml with device:", torch_directml.device_name(device_index))
45
- # torch_directml.disable_tiled_resources(True)
46
- lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
47
-
48
- try:
49
- import intel_extension_for_pytorch as ipex
50
- if torch.xpu.is_available():
51
- xpu_available = True
52
- except:
53
- pass
54
-
55
- try:
56
- if torch.backends.mps.is_available():
57
- cpu_state = CPUState.MPS
58
- import torch.mps
59
- except:
60
- pass
61
-
62
- if args.always_cpu:
63
- if args.always_cpu > 0:
64
- torch.set_num_threads(args.always_cpu)
65
- print(f"Running on {torch.get_num_threads()} CPU threads")
66
- cpu_state = CPUState.CPU
67
-
68
- def is_intel_xpu():
69
- global cpu_state
70
- global xpu_available
71
- if cpu_state == CPUState.GPU:
72
- if xpu_available:
73
- return True
74
- return False
75
-
76
- def get_torch_device():
77
- global directml_enabled
78
- global cpu_state
79
- if directml_enabled:
80
- global directml_device
81
- return directml_device
82
- if cpu_state == CPUState.MPS:
83
- return torch.device("mps")
84
- if cpu_state == CPUState.CPU:
85
- return torch.device("cpu")
86
- else:
87
- if is_intel_xpu():
88
- return torch.device("xpu")
89
- else:
90
- return torch.device(torch.cuda.current_device())
91
-
- def get_total_memory(dev=None, torch_total_too=False):
-     global directml_enabled
-     if dev is None:
-         dev = get_torch_device()
-
-     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-         mem_total = psutil.virtual_memory().total
-         mem_total_torch = mem_total
-     else:
-         if directml_enabled:
-             mem_total = 1024 * 1024 * 1024 #TODO
-             mem_total_torch = mem_total
-         elif is_intel_xpu():
-             stats = torch.xpu.memory_stats(dev)
-             mem_reserved = stats['reserved_bytes.all.current']
-             mem_total = torch.xpu.get_device_properties(dev).total_memory
-             mem_total_torch = mem_reserved
-         else:
-             stats = torch.cuda.memory_stats(dev)
-             mem_reserved = stats['reserved_bytes.all.current']
-             _, mem_total_cuda = torch.cuda.mem_get_info(dev)
-             mem_total_torch = mem_reserved
-             mem_total = mem_total_cuda
-
-     if torch_total_too:
-         return (mem_total, mem_total_torch)
-     else:
-         return mem_total
-
- total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
- total_ram = psutil.virtual_memory().total / (1024 * 1024)
- print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
- if not args.always_normal_vram and not args.always_cpu:
-     if lowvram_available and total_vram <= 4096:
-         print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --always-normal-vram")
-         set_vram_to = VRAMState.LOW_VRAM
-
- try:
-     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
- except:
-     OOM_EXCEPTION = Exception
-
134
- XFORMERS_VERSION = ""
135
- XFORMERS_ENABLED_VAE = True
136
- if args.disable_xformers:
137
- XFORMERS_IS_AVAILABLE = False
138
- else:
139
- try:
140
- import xformers
141
- import xformers.ops
142
- XFORMERS_IS_AVAILABLE = True
143
- try:
144
- XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
145
- except:
146
- pass
147
- try:
148
- XFORMERS_VERSION = xformers.version.__version__
149
- print("xformers version:", XFORMERS_VERSION)
150
- if XFORMERS_VERSION.startswith("0.0.18"):
151
- print()
152
- print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
153
- print("Please downgrade or upgrade xformers to a different version.")
154
- print()
155
- XFORMERS_ENABLED_VAE = False
156
- except:
157
- pass
158
- except:
159
- XFORMERS_IS_AVAILABLE = False
160
-
161
- def is_nvidia():
162
- global cpu_state
163
- if cpu_state == CPUState.GPU:
164
- if torch.version.cuda:
165
- return True
166
- return False
167
-
168
- ENABLE_PYTORCH_ATTENTION = False
169
- if args.attention_pytorch:
170
- ENABLE_PYTORCH_ATTENTION = True
171
- XFORMERS_IS_AVAILABLE = False
172
-
173
- VAE_DTYPE = torch.float32
174
-
175
- try:
176
- if is_nvidia():
177
- torch_version = torch.version.__version__
178
- if int(torch_version[0]) >= 2:
179
- if ENABLE_PYTORCH_ATTENTION == False and args.attention_split == False and args.attention_quad == False:
180
- ENABLE_PYTORCH_ATTENTION = True
181
- if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
182
- VAE_DTYPE = torch.bfloat16
183
- if is_intel_xpu():
184
- if args.attention_split == False and args.attention_quad == False:
185
- ENABLE_PYTORCH_ATTENTION = True
186
- except:
187
- pass
188
-
189
- if is_intel_xpu():
190
- VAE_DTYPE = torch.bfloat16
191
-
192
- if args.vae_in_cpu:
193
- VAE_DTYPE = torch.float32
194
-
195
- if args.vae_in_fp16:
196
- VAE_DTYPE = torch.float16
197
- elif args.vae_in_bf16:
198
- VAE_DTYPE = torch.bfloat16
199
- elif args.vae_in_fp32:
200
- VAE_DTYPE = torch.float32
201
-
202
-
203
- if ENABLE_PYTORCH_ATTENTION:
204
- torch.backends.cuda.enable_math_sdp(True)
205
- torch.backends.cuda.enable_flash_sdp(True)
206
- torch.backends.cuda.enable_mem_efficient_sdp(True)
207
-
208
- if args.always_low_vram:
209
- set_vram_to = VRAMState.LOW_VRAM
210
- lowvram_available = True
211
- elif args.always_no_vram:
212
- set_vram_to = VRAMState.NO_VRAM
213
- elif args.always_high_vram or args.always_gpu:
214
- vram_state = VRAMState.HIGH_VRAM
215
-
216
- FORCE_FP32 = False
217
- FORCE_FP16 = False
218
- if args.all_in_fp32:
219
- print("Forcing FP32, if this improves things please report it.")
220
- FORCE_FP32 = True
221
-
222
- if args.all_in_fp16:
223
- print("Forcing FP16.")
224
- FORCE_FP16 = True
225
-
226
- if lowvram_available:
227
- if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
228
- vram_state = set_vram_to
229
-
230
-
231
- if cpu_state != CPUState.GPU:
232
- vram_state = VRAMState.DISABLED
233
-
234
- if cpu_state == CPUState.MPS:
235
- vram_state = VRAMState.SHARED
236
-
237
- print(f"Set vram state to: {vram_state.name}")
238
-
239
- ALWAYS_VRAM_OFFLOAD = args.always_offload_from_vram
240
-
241
- if ALWAYS_VRAM_OFFLOAD:
242
- print("Always offload VRAM")
243
-
244
- def get_torch_device_name(device):
245
- if hasattr(device, 'type'):
246
- if device.type == "cuda":
247
- try:
248
- allocator_backend = torch.cuda.get_allocator_backend()
249
- except:
250
- allocator_backend = ""
251
- return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
252
- else:
253
- return "{}".format(device.type)
254
- elif is_intel_xpu():
255
- return "{} {}".format(device, torch.xpu.get_device_name(device))
256
- else:
257
- return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
258
-
259
- try:
260
- print("Device:", get_torch_device_name(get_torch_device()))
261
- except:
262
- print("Could not pick default device.")
263
-
264
- print("VAE dtype:", VAE_DTYPE)
265
-
266
- current_loaded_models = []
267
-
268
- def module_size(module):
269
- module_mem = 0
270
- sd = module.state_dict()
271
- for k in sd:
272
- t = sd[k]
273
- module_mem += t.nelement() * t.element_size()
274
- return module_mem
275
-
276
- class LoadedModel:
277
- def __init__(self, model):
278
- self.model = model
279
- self.model_accelerated = False
280
- self.device = model.load_device
281
-
282
- def model_memory(self):
283
- return self.model.model_size()
284
-
285
- def model_memory_required(self, device):
286
- if device == self.model.current_device:
287
- return 0
288
- else:
289
- return self.model_memory()
290
-
291
- def model_load(self, lowvram_model_memory=0):
292
- patch_model_to = None
293
- if lowvram_model_memory == 0:
294
- patch_model_to = self.device
295
-
296
- self.model.model_patches_to(self.device)
297
- self.model.model_patches_to(self.model.model_dtype())
298
-
299
- try:
300
- self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
301
- except Exception as e:
302
- self.model.unpatch_model(self.model.offload_device)
303
- self.model_unload()
304
- raise e
305
-
306
- if lowvram_model_memory > 0:
307
- print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
308
- mem_counter = 0
309
- for m in self.real_model.modules():
310
- if hasattr(m, "ldm_patched_cast_weights"):
311
- m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
312
- m.ldm_patched_cast_weights = True
313
- module_mem = module_size(m)
314
- if mem_counter + module_mem < lowvram_model_memory:
315
- m.to(self.device)
316
- mem_counter += module_mem
317
- elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
318
- m.to(self.device)
319
- mem_counter += module_size(m)
320
- print("lowvram: loaded module regularly", m)
321
-
322
- self.model_accelerated = True
323
-
324
- if is_intel_xpu() and not args.disable_ipex_hijack:
325
- self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
326
-
327
- return self.real_model
328
-
329
- def model_unload(self):
330
- if self.model_accelerated:
331
- for m in self.real_model.modules():
332
- if hasattr(m, "prev_ldm_patched_cast_weights"):
333
- m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
334
- del m.prev_ldm_patched_cast_weights
335
-
336
- self.model_accelerated = False
337
-
338
- self.model.unpatch_model(self.model.offload_device)
339
- self.model.model_patches_to(self.model.offload_device)
340
-
341
- def __eq__(self, other):
342
- return self.model is other.model
343
-
344
- def minimum_inference_memory():
345
- return (1024 * 1024 * 1024)
346
-
347
- def unload_model_clones(model):
348
- to_unload = []
349
- for i in range(len(current_loaded_models)):
350
- if model.is_clone(current_loaded_models[i].model):
351
- to_unload = [i] + to_unload
352
-
353
- for i in to_unload:
354
- print("unload clone", i)
355
- current_loaded_models.pop(i).model_unload()
356
-
357
- def free_memory(memory_required, device, keep_loaded=[]):
358
- unloaded_model = False
359
- for i in range(len(current_loaded_models) -1, -1, -1):
360
- if not ALWAYS_VRAM_OFFLOAD:
361
- if get_free_memory(device) > memory_required:
362
- break
363
- shift_model = current_loaded_models[i]
364
- if shift_model.device == device:
365
- if shift_model not in keep_loaded:
366
- m = current_loaded_models.pop(i)
367
- m.model_unload()
368
- del m
369
- unloaded_model = True
370
-
371
- if unloaded_model:
372
- soft_empty_cache()
373
- else:
374
- if vram_state != VRAMState.HIGH_VRAM:
375
- mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
376
- if mem_free_torch > mem_free_total * 0.25:
377
- soft_empty_cache()
378
-
379
- def load_models_gpu(models, memory_required=0):
380
- global vram_state
381
-
382
- inference_memory = minimum_inference_memory()
383
- extra_mem = max(inference_memory, memory_required)
384
-
385
- models_to_load = []
386
- models_already_loaded = []
387
- for x in models:
388
- loaded_model = LoadedModel(x)
389
-
390
- if loaded_model in current_loaded_models:
391
- index = current_loaded_models.index(loaded_model)
392
- current_loaded_models.insert(0, current_loaded_models.pop(index))
393
- models_already_loaded.append(loaded_model)
394
- else:
395
- if hasattr(x, "model"):
396
- print(f"Requested to load {x.model.__class__.__name__}")
397
- models_to_load.append(loaded_model)
398
-
399
- if len(models_to_load) == 0:
400
- devs = set(map(lambda a: a.device, models_already_loaded))
401
- for d in devs:
402
- if d != torch.device("cpu"):
403
- free_memory(extra_mem, d, models_already_loaded)
404
- return
405
-
406
- print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
407
-
408
- total_memory_required = {}
409
- for loaded_model in models_to_load:
410
- unload_model_clones(loaded_model.model)
411
- total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
412
-
413
- for device in total_memory_required:
414
- if device != torch.device("cpu"):
415
- free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
416
-
417
- for loaded_model in models_to_load:
418
- model = loaded_model.model
419
- torch_dev = model.load_device
420
- if is_device_cpu(torch_dev):
421
- vram_set_state = VRAMState.DISABLED
422
- else:
423
- vram_set_state = vram_state
424
- lowvram_model_memory = 0
425
- if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
426
- model_size = loaded_model.model_memory_required(torch_dev)
427
- current_free_mem = get_free_memory(torch_dev)
428
- lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
429
- if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
430
- vram_set_state = VRAMState.LOW_VRAM
431
- else:
432
- lowvram_model_memory = 0
433
-
434
- if vram_set_state == VRAMState.NO_VRAM:
435
- lowvram_model_memory = 64 * 1024 * 1024
436
-
437
- cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
438
- current_loaded_models.insert(0, loaded_model)
439
- return
440
-
441
-
442
- def load_model_gpu(model):
443
- return load_models_gpu([model])
444
-
445
- def cleanup_models():
446
- to_delete = []
447
- for i in range(len(current_loaded_models)):
448
- if sys.getrefcount(current_loaded_models[i].model) <= 2:
449
- to_delete = [i] + to_delete
450
-
451
- for i in to_delete:
452
- x = current_loaded_models.pop(i)
453
- x.model_unload()
454
- del x
455
-
456
- def dtype_size(dtype):
457
- dtype_size = 4
458
- if dtype == torch.float16 or dtype == torch.bfloat16:
459
- dtype_size = 2
460
- elif dtype == torch.float32:
461
- dtype_size = 4
462
- else:
463
- try:
464
- dtype_size = dtype.itemsize
465
- except: #Old pytorch doesn't have .itemsize
466
- pass
467
- return dtype_size
468
-
469
- def unet_offload_device():
470
- if vram_state == VRAMState.HIGH_VRAM:
471
- return get_torch_device()
472
- else:
473
- return torch.device("cpu")
474
-
475
- def unet_inital_load_device(parameters, dtype):
476
- torch_dev = get_torch_device()
477
- if vram_state == VRAMState.HIGH_VRAM:
478
- return torch_dev
479
-
480
- cpu_dev = torch.device("cpu")
481
- if ALWAYS_VRAM_OFFLOAD:
482
- return cpu_dev
483
-
484
- model_size = dtype_size(dtype) * parameters
485
-
486
- mem_dev = get_free_memory(torch_dev)
487
- mem_cpu = get_free_memory(cpu_dev)
488
- if mem_dev > mem_cpu and model_size < mem_dev:
489
- return torch_dev
490
- else:
491
- return cpu_dev
492
-
493
- def unet_dtype(device=None, model_params=0):
494
- if args.unet_in_bf16:
495
- return torch.bfloat16
496
- if args.unet_in_fp16:
497
- return torch.float16
498
- if args.unet_in_fp8_e4m3fn:
499
- return torch.float8_e4m3fn
500
- if args.unet_in_fp8_e5m2:
501
- return torch.float8_e5m2
502
- if should_use_fp16(device=device, model_params=model_params):
503
- return torch.float16
504
- return torch.float32
505
-
506
- # None means no manual cast
507
- def unet_manual_cast(weight_dtype, inference_device):
508
- if weight_dtype == torch.float32:
509
- return None
510
-
511
- fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False)
512
- if fp16_supported and weight_dtype == torch.float16:
513
- return None
514
-
515
- if fp16_supported:
516
- return torch.float16
517
- else:
518
- return torch.float32
519
-
520
- def text_encoder_offload_device():
521
- if args.always_gpu:
522
- return get_torch_device()
523
- else:
524
- return torch.device("cpu")
525
-
526
- def text_encoder_device():
527
- if args.always_gpu:
528
- return get_torch_device()
529
- elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
530
- if is_intel_xpu():
531
- return torch.device("cpu")
532
- if should_use_fp16(prioritize_performance=False):
533
- return get_torch_device()
534
- else:
535
- return torch.device("cpu")
536
- else:
537
- return torch.device("cpu")
538
-
539
- def text_encoder_dtype(device=None):
540
- if args.clip_in_fp8_e4m3fn:
541
- return torch.float8_e4m3fn
542
- elif args.clip_in_fp8_e5m2:
543
- return torch.float8_e5m2
544
- elif args.clip_in_fp16:
545
- return torch.float16
546
- elif args.clip_in_fp32:
547
- return torch.float32
548
-
549
- if is_device_cpu(device):
550
- return torch.float16
551
-
552
- if should_use_fp16(device, prioritize_performance=False):
553
- return torch.float16
554
- else:
555
- return torch.float32
556
-
557
- def intermediate_device():
558
- if args.always_gpu:
559
- return get_torch_device()
560
- else:
561
- return torch.device("cpu")
562
-
563
- def vae_device():
564
- if args.vae_in_cpu:
565
- return torch.device("cpu")
566
- return get_torch_device()
567
-
568
- def vae_offload_device():
569
- if args.always_gpu:
570
- return get_torch_device()
571
- else:
572
- return torch.device("cpu")
573
-
574
- def vae_dtype():
575
- global VAE_DTYPE
576
- return VAE_DTYPE
577
-
578
- def get_autocast_device(dev):
579
- if hasattr(dev, 'type'):
580
- return dev.type
581
- return "cuda"
582
-
583
- def supports_dtype(device, dtype): #TODO
584
- if dtype == torch.float32:
585
- return True
586
- if is_device_cpu(device):
587
- return False
588
- if dtype == torch.float16:
589
- return True
590
- if dtype == torch.bfloat16:
591
- return True
592
- return False
593
-
594
- def device_supports_non_blocking(device):
595
- if is_device_mps(device):
596
- return False #pytorch bug? mps doesn't support non blocking
597
- return True
598
-
599
- def cast_to_device(tensor, device, dtype, copy=False):
600
- device_supports_cast = False
601
- if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
602
- device_supports_cast = True
603
- elif tensor.dtype == torch.bfloat16:
604
- if hasattr(device, 'type') and device.type.startswith("cuda"):
605
- device_supports_cast = True
606
- elif is_intel_xpu():
607
- device_supports_cast = True
608
-
609
- non_blocking = device_supports_non_blocking(device)
610
-
611
- if device_supports_cast:
612
- if copy:
613
- if tensor.device == device:
614
- return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
615
- return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
616
- else:
617
- return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
618
- else:
619
- return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
620
-
621
- def xformers_enabled():
622
- global directml_enabled
623
- global cpu_state
624
- if cpu_state != CPUState.GPU:
625
- return False
626
- if is_intel_xpu():
627
- return False
628
- if directml_enabled:
629
- return False
630
- return XFORMERS_IS_AVAILABLE
631
-
632
-
633
- def xformers_enabled_vae():
634
- enabled = xformers_enabled()
635
- if not enabled:
636
- return False
637
-
638
- return XFORMERS_ENABLED_VAE
639
-
640
- def pytorch_attention_enabled():
641
- global ENABLE_PYTORCH_ATTENTION
642
- return ENABLE_PYTORCH_ATTENTION
643
-
644
- def pytorch_attention_flash_attention():
645
- global ENABLE_PYTORCH_ATTENTION
646
- if ENABLE_PYTORCH_ATTENTION:
647
- #TODO: more reliable way of checking for flash attention?
648
- if is_nvidia(): #pytorch flash attention only works on Nvidia
649
- return True
650
- return False
651
-
652
- def get_free_memory(dev=None, torch_free_too=False):
653
- global directml_enabled
654
- if dev is None:
655
- dev = get_torch_device()
656
-
657
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
658
- mem_free_total = psutil.virtual_memory().available
659
- mem_free_torch = mem_free_total
660
- else:
661
- if directml_enabled:
662
- mem_free_total = 1024 * 1024 * 1024 #TODO
663
- mem_free_torch = mem_free_total
664
- elif is_intel_xpu():
665
- stats = torch.xpu.memory_stats(dev)
666
- mem_active = stats['active_bytes.all.current']
667
- mem_allocated = stats['allocated_bytes.all.current']
668
- mem_reserved = stats['reserved_bytes.all.current']
669
- mem_free_torch = mem_reserved - mem_active
670
- mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
671
- else:
672
- stats = torch.cuda.memory_stats(dev)
673
- mem_active = stats['active_bytes.all.current']
674
- mem_reserved = stats['reserved_bytes.all.current']
675
- mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
676
- mem_free_torch = mem_reserved - mem_active
677
- mem_free_total = mem_free_cuda + mem_free_torch
678
-
679
- if torch_free_too:
680
- return (mem_free_total, mem_free_torch)
681
- else:
682
- return mem_free_total
683
-
684
- def cpu_mode():
685
- global cpu_state
686
- return cpu_state == CPUState.CPU
687
-
688
- def mps_mode():
689
- global cpu_state
690
- return cpu_state == CPUState.MPS
691
-
692
- def is_device_cpu(device):
693
- if hasattr(device, 'type'):
694
- if (device.type == 'cpu'):
695
- return True
696
- return False
697
-
698
- def is_device_mps(device):
699
- if hasattr(device, 'type'):
700
- if (device.type == 'mps'):
701
- return True
702
- return False
703
-
704
- def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
705
- global directml_enabled
706
-
707
- if device is not None:
708
- if is_device_cpu(device):
709
- return False
710
-
711
- if FORCE_FP16:
712
- return True
713
-
714
- if device is not None: #TODO
715
- if is_device_mps(device):
716
- return False
717
-
718
- if FORCE_FP32:
719
- return False
720
-
721
- if directml_enabled:
722
- return False
723
-
724
- if cpu_mode() or mps_mode():
725
- return False #TODO ?
726
-
727
- if is_intel_xpu():
728
- return True
729
-
730
- if torch.cuda.is_bf16_supported():
731
- return True
732
-
733
- props = torch.cuda.get_device_properties("cuda")
734
- if props.major < 6:
735
- return False
736
-
737
- fp16_works = False
738
- #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
739
- #when the model doesn't actually fit on the card
740
- #TODO: actually test if GP106 and others have the same type of behavior
741
- nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
742
- for x in nvidia_10_series:
743
- if x in props.name.lower():
744
- fp16_works = True
745
-
746
- if fp16_works:
747
- free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
748
- if (not prioritize_performance) or model_params * 4 > free_model_memory:
749
- return True
750
-
751
- if props.major < 7:
752
- return False
753
-
754
- #FP16 is just broken on these cards
755
- nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
756
- for x in nvidia_16_series:
757
- if x in props.name:
758
- return False
759
-
760
- return True
761
-
762
- def soft_empty_cache(force=False):
763
- global cpu_state
764
- if cpu_state == CPUState.MPS:
765
- torch.mps.empty_cache()
766
- elif is_intel_xpu():
767
- torch.xpu.empty_cache()
768
- elif torch.cuda.is_available():
769
- if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
770
- torch.cuda.empty_cache()
771
- torch.cuda.ipc_collect()
772
-
773
- def unload_all_models():
774
- free_memory(1e30, get_torch_device())
775
-
776
-
777
- def resolve_lowvram_weight(weight, model, key): #TODO: remove
778
- return weight
779
-
780
- #TODO: might be cleaner to put this somewhere else
781
- import threading
782
-
783
- class InterruptProcessingException(Exception):
784
- pass
785
-
786
- interrupt_processing_mutex = threading.RLock()
787
-
788
- interrupt_processing = False
789
- def interrupt_current_processing(value=True):
790
- global interrupt_processing
791
- global interrupt_processing_mutex
792
- with interrupt_processing_mutex:
793
- interrupt_processing = value
794
-
795
- def processing_interrupted():
796
- global interrupt_processing
797
- global interrupt_processing_mutex
798
- with interrupt_processing_mutex:
799
- return interrupt_processing
800
-
801
- def throw_exception_if_processing_interrupted():
802
- global interrupt_processing
803
- global interrupt_processing_mutex
804
- with interrupt_processing_mutex:
805
- if interrupt_processing:
806
- interrupt_processing = False
807
- raise InterruptProcessingException()
 
+ import psutil
+ from enum import Enum
+ from ldm_patched.modules.args_parser import args
+ import ldm_patched.modules.utils
+ import torch
+ import sys
+
8
+ class VRAMState(Enum):
9
+ DISABLED = 0 #No vram present: no need to move models to vram
10
+ NO_VRAM = 1 #Very low vram: enable all the options to save vram
11
+ LOW_VRAM = 2
12
+ NORMAL_VRAM = 3
13
+ HIGH_VRAM = 4
14
+ SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
15
+
16
+ class CPUState(Enum):
17
+ GPU = 0
18
+ CPU = 1
19
+ MPS = 2
20
+
21
+ # Determine VRAM State
22
+ vram_state = VRAMState.NORMAL_VRAM
23
+ set_vram_to = VRAMState.NORMAL_VRAM
24
+ cpu_state = CPUState.GPU
25
+
26
+ total_vram = 0
27
+
28
+ lowvram_available = True
29
+ xpu_available = False
30
+
31
+ if args.pytorch_deterministic:
32
+ print("Using deterministic algorithms for pytorch")
33
+ torch.use_deterministic_algorithms(True, warn_only=True)
34
+
35
+ directml_enabled = False
36
+ if args.directml is not None:
37
+ import torch_directml
38
+ directml_enabled = True
39
+ device_index = args.directml
40
+ if device_index < 0:
41
+ directml_device = torch_directml.device()
42
+ else:
43
+ directml_device = torch_directml.device(device_index)
44
+ print("Using directml with device:", torch_directml.device_name(device_index))
45
+ # torch_directml.disable_tiled_resources(True)
46
+ lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
47
+
48
+ try:
49
+ import intel_extension_for_pytorch as ipex
50
+ if torch.xpu.is_available():
51
+ xpu_available = True
52
+ except:
53
+ pass
54
+
55
+ try:
56
+ if torch.backends.mps.is_available():
57
+ cpu_state = CPUState.MPS
58
+ import torch.mps
59
+ except:
60
+ pass
61
+
62
+ if args.always_cpu:
63
+ if args.always_cpu > 0:
64
+ torch.set_num_threads(args.always_cpu)
65
+ print(f"Running on {torch.get_num_threads()} CPU threads")
66
+ cpu_state = CPUState.CPU
67
+
68
+ def is_intel_xpu():
69
+ global cpu_state
70
+ global xpu_available
71
+ if cpu_state == CPUState.GPU:
72
+ if xpu_available:
73
+ return True
74
+ return False
75
+
76
+ def get_torch_device():
77
+ global directml_enabled
78
+ global cpu_state
79
+ if directml_enabled:
80
+ global directml_device
81
+ return directml_device
82
+ if cpu_state == CPUState.MPS:
83
+ return torch.device("mps")
84
+ if cpu_state == CPUState.CPU:
85
+ return torch.device("cpu")
86
+ else:
87
+ if is_intel_xpu():
88
+ return torch.device("xpu")
89
+ else:
90
+ return torch.device(torch.cuda.current_device())
91
+
+ def get_total_memory(dev=None, torch_total_too=False):
+     global directml_enabled
+     if dev is None:
+         dev = get_torch_device()
+
+     stats = torch.cuda.memory_stats(dev)
+     mem_reserved = stats['reserved_bytes.all.current']
+     _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+     mem_total_torch = mem_reserved
+     mem_total = mem_total_cuda
+
+     return mem_total
+
+ #total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
+ #total_ram = psutil.virtual_memory().total / (1024 * 1024)
+ #print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+ #if not args.always_normal_vram and not args.always_cpu:
+ #    if lowvram_available and total_vram <= 4096:
+ #        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --always-normal-vram")
+ #        set_vram_to = VRAMState.LOW_VRAM
+ #
+ #try:
+ #    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+ #except:
+ #    OOM_EXCEPTION = Exception
+
118
+ XFORMERS_VERSION = ""
119
+ XFORMERS_ENABLED_VAE = True
120
+ if args.disable_xformers:
121
+ XFORMERS_IS_AVAILABLE = False
122
+ else:
123
+ try:
124
+ import xformers
125
+ import xformers.ops
126
+ XFORMERS_IS_AVAILABLE = True
127
+ try:
128
+ XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
129
+ except:
130
+ pass
131
+ try:
132
+ XFORMERS_VERSION = xformers.version.__version__
133
+ print("xformers version:", XFORMERS_VERSION)
134
+ if XFORMERS_VERSION.startswith("0.0.18"):
135
+ print()
136
+ print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
137
+ print("Please downgrade or upgrade xformers to a different version.")
138
+ print()
139
+ XFORMERS_ENABLED_VAE = False
140
+ except:
141
+ pass
142
+ except:
143
+ XFORMERS_IS_AVAILABLE = False
144
+
145
+ def is_nvidia():
146
+ global cpu_state
147
+ if cpu_state == CPUState.GPU:
148
+ if torch.version.cuda:
149
+ return True
150
+ return False
151
+
152
+ ENABLE_PYTORCH_ATTENTION = False
153
+ if args.attention_pytorch:
154
+ ENABLE_PYTORCH_ATTENTION = True
155
+ XFORMERS_IS_AVAILABLE = False
156
+
157
+ VAE_DTYPE = torch.float32
158
+
159
+ try:
160
+ if is_nvidia():
161
+ torch_version = torch.version.__version__
162
+ if int(torch_version[0]) >= 2:
163
+ if ENABLE_PYTORCH_ATTENTION == False and args.attention_split == False and args.attention_quad == False:
164
+ ENABLE_PYTORCH_ATTENTION = True
165
+ if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
166
+ VAE_DTYPE = torch.bfloat16
167
+ if is_intel_xpu():
168
+ if args.attention_split == False and args.attention_quad == False:
169
+ ENABLE_PYTORCH_ATTENTION = True
170
+ except:
171
+ pass
172
+
173
+ if is_intel_xpu():
174
+ VAE_DTYPE = torch.bfloat16
175
+
176
+ if args.vae_in_cpu:
177
+ VAE_DTYPE = torch.float32
178
+
179
+ if args.vae_in_fp16:
180
+ VAE_DTYPE = torch.float16
181
+ elif args.vae_in_bf16:
182
+ VAE_DTYPE = torch.bfloat16
183
+ elif args.vae_in_fp32:
184
+ VAE_DTYPE = torch.float32
185
+
186
+
187
+ if ENABLE_PYTORCH_ATTENTION:
188
+ torch.backends.cuda.enable_math_sdp(True)
189
+ torch.backends.cuda.enable_flash_sdp(True)
190
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
191
+
192
+ if args.always_low_vram:
193
+ set_vram_to = VRAMState.LOW_VRAM
194
+ lowvram_available = True
195
+ elif args.always_no_vram:
196
+ set_vram_to = VRAMState.NO_VRAM
197
+ elif args.always_high_vram or args.always_gpu:
198
+ vram_state = VRAMState.HIGH_VRAM
199
+
200
+ FORCE_FP32 = False
201
+ FORCE_FP16 = False
202
+ if args.all_in_fp32:
203
+ print("Forcing FP32, if this improves things please report it.")
204
+ FORCE_FP32 = True
205
+
206
+ if args.all_in_fp16:
207
+ print("Forcing FP16.")
208
+ FORCE_FP16 = True
209
+
210
+ if lowvram_available:
211
+ if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
212
+ vram_state = set_vram_to
213
+
214
+
215
+ if cpu_state != CPUState.GPU:
216
+ vram_state = VRAMState.DISABLED
217
+
218
+ if cpu_state == CPUState.MPS:
219
+ vram_state = VRAMState.SHARED
220
+
221
+ print(f"Set vram state to: {vram_state.name}")
222
+
223
+ ALWAYS_VRAM_OFFLOAD = args.always_offload_from_vram
224
+
225
+ if ALWAYS_VRAM_OFFLOAD:
226
+ print("Always offload VRAM")
227
+
228
+ def get_torch_device_name(device):
229
+ if hasattr(device, 'type'):
230
+ if device.type == "cuda":
231
+ try:
232
+ allocator_backend = torch.cuda.get_allocator_backend()
233
+ except:
234
+ allocator_backend = ""
235
+ return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
236
+ else:
237
+ return "{}".format(device.type)
238
+ elif is_intel_xpu():
239
+ return "{} {}".format(device, torch.xpu.get_device_name(device))
240
+ else:
241
+ return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
242
+
243
+ try:
244
+ print("Device:", get_torch_device_name(get_torch_device()))
245
+ except:
246
+ print("Could not pick default device.")
247
+
248
+ print("VAE dtype:", VAE_DTYPE)
249
+
250
+ current_loaded_models = []
251
+
252
+ def module_size(module):
253
+ module_mem = 0
254
+ sd = module.state_dict()
255
+ for k in sd:
256
+ t = sd[k]
257
+ module_mem += t.nelement() * t.element_size()
258
+ return module_mem
259
+
260
+ class LoadedModel:
261
+ def __init__(self, model):
262
+ self.model = model
263
+ self.model_accelerated = False
264
+ self.device = model.load_device
265
+
266
+ def model_memory(self):
267
+ return self.model.model_size()
268
+
269
+ def model_memory_required(self, device):
270
+ if device == self.model.current_device:
271
+ return 0
272
+ else:
273
+ return self.model_memory()
274
+
275
+ def model_load(self, lowvram_model_memory=0):
276
+ patch_model_to = None
277
+ if lowvram_model_memory == 0:
278
+ patch_model_to = self.device
279
+
280
+ self.model.model_patches_to(self.device)
281
+ self.model.model_patches_to(self.model.model_dtype())
282
+
283
+ try:
284
+ self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
285
+ except Exception as e:
286
+ self.model.unpatch_model(self.model.offload_device)
287
+ self.model_unload()
288
+ raise e
289
+
290
+ if lowvram_model_memory > 0:
291
+ print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
292
+ mem_counter = 0
293
+ for m in self.real_model.modules():
294
+ if hasattr(m, "ldm_patched_cast_weights"):
295
+ m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
296
+ m.ldm_patched_cast_weights = True
297
+ module_mem = module_size(m)
298
+ if mem_counter + module_mem < lowvram_model_memory:
299
+ m.to(self.device)
300
+ mem_counter += module_mem
301
+ elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
302
+ m.to(self.device)
303
+ mem_counter += module_size(m)
304
+ print("lowvram: loaded module regularly", m)
305
+
306
+ self.model_accelerated = True
307
+
308
+ if is_intel_xpu() and not args.disable_ipex_hijack:
309
+ self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
310
+
311
+ return self.real_model
312
+
313
+ def model_unload(self):
314
+ if self.model_accelerated:
315
+ for m in self.real_model.modules():
316
+ if hasattr(m, "prev_ldm_patched_cast_weights"):
317
+ m.ldm_patched_cast_weights = m.prev_ldm_patched_cast_weights
318
+ del m.prev_ldm_patched_cast_weights
319
+
320
+ self.model_accelerated = False
321
+
322
+ self.model.unpatch_model(self.model.offload_device)
323
+ self.model.model_patches_to(self.model.offload_device)
324
+
325
+ def __eq__(self, other):
326
+ return self.model is other.model
327
+
328
+ def minimum_inference_memory():
329
+ return (1024 * 1024 * 1024)
330
+
331
+ def unload_model_clones(model):
332
+ to_unload = []
333
+ for i in range(len(current_loaded_models)):
334
+ if model.is_clone(current_loaded_models[i].model):
335
+ to_unload = [i] + to_unload
336
+
337
+ for i in to_unload:
338
+ print("unload clone", i)
339
+ current_loaded_models.pop(i).model_unload()
340
+
341
+ def free_memory(memory_required, device, keep_loaded=[]):
342
+ unloaded_model = False
343
+ for i in range(len(current_loaded_models) -1, -1, -1):
344
+ if not ALWAYS_VRAM_OFFLOAD:
345
+ if get_free_memory(device) > memory_required:
346
+ break
347
+ shift_model = current_loaded_models[i]
348
+ if shift_model.device == device:
349
+ if shift_model not in keep_loaded:
350
+ m = current_loaded_models.pop(i)
351
+ m.model_unload()
352
+ del m
353
+ unloaded_model = True
354
+
355
+ if unloaded_model:
356
+ soft_empty_cache()
357
+ else:
358
+ if vram_state != VRAMState.HIGH_VRAM:
359
+ mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
360
+ if mem_free_torch > mem_free_total * 0.25:
361
+ soft_empty_cache()
362
+
363
+ def load_models_gpu(models, memory_required=0):
364
+ global vram_state
365
+
366
+ inference_memory = minimum_inference_memory()
367
+ extra_mem = max(inference_memory, memory_required)
368
+
369
+ models_to_load = []
370
+ models_already_loaded = []
371
+ for x in models:
372
+ loaded_model = LoadedModel(x)
373
+
374
+ if loaded_model in current_loaded_models:
375
+ index = current_loaded_models.index(loaded_model)
376
+ current_loaded_models.insert(0, current_loaded_models.pop(index))
377
+ models_already_loaded.append(loaded_model)
378
+ else:
379
+ if hasattr(x, "model"):
380
+ print(f"Requested to load {x.model.__class__.__name__}")
381
+ models_to_load.append(loaded_model)
382
+
383
+ if len(models_to_load) == 0:
384
+ devs = set(map(lambda a: a.device, models_already_loaded))
385
+ for d in devs:
386
+ if d != torch.device("cpu"):
387
+ free_memory(extra_mem, d, models_already_loaded)
388
+ return
389
+
390
+ print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
391
+
392
+ total_memory_required = {}
393
+ for loaded_model in models_to_load:
394
+ unload_model_clones(loaded_model.model)
395
+ total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
396
+
397
+ for device in total_memory_required:
398
+ if device != torch.device("cpu"):
399
+ free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
400
+
401
+ for loaded_model in models_to_load:
402
+ model = loaded_model.model
403
+ torch_dev = model.load_device
404
+ if is_device_cpu(torch_dev):
405
+ vram_set_state = VRAMState.DISABLED
406
+ else:
407
+ vram_set_state = vram_state
408
+ lowvram_model_memory = 0
409
+ if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
410
+ model_size = loaded_model.model_memory_required(torch_dev)
411
+ current_free_mem = get_free_memory(torch_dev)
412
+ lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
413
+ if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
414
+ vram_set_state = VRAMState.LOW_VRAM
415
+ else:
416
+ lowvram_model_memory = 0
417
+
418
+ if vram_set_state == VRAMState.NO_VRAM:
419
+ lowvram_model_memory = 64 * 1024 * 1024
420
+
421
+ cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
422
+ current_loaded_models.insert(0, loaded_model)
423
+ return
424
+
425
+
426
+ def load_model_gpu(model):
427
+ return load_models_gpu([model])
428
+
429
+ def cleanup_models():
430
+ to_delete = []
431
+ for i in range(len(current_loaded_models)):
432
+ if sys.getrefcount(current_loaded_models[i].model) <= 2:
433
+ to_delete = [i] + to_delete
434
+
435
+ for i in to_delete:
436
+ x = current_loaded_models.pop(i)
437
+ x.model_unload()
438
+ del x
439
+
440
+ def dtype_size(dtype):
441
+ dtype_size = 4
442
+ if dtype == torch.float16 or dtype == torch.bfloat16:
443
+ dtype_size = 2
444
+ elif dtype == torch.float32:
445
+ dtype_size = 4
446
+ else:
447
+ try:
448
+ dtype_size = dtype.itemsize
449
+ except: #Old pytorch doesn't have .itemsize
450
+ pass
451
+ return dtype_size
452
+
453
+ def unet_offload_device():
454
+ if vram_state == VRAMState.HIGH_VRAM:
455
+ return get_torch_device()
456
+ else:
457
+ return torch.device("cpu")
458
+
459
+ def unet_inital_load_device(parameters, dtype):
460
+ torch_dev = get_torch_device()
461
+ if vram_state == VRAMState.HIGH_VRAM:
462
+ return torch_dev
463
+
464
+ cpu_dev = torch.device("cpu")
465
+ if ALWAYS_VRAM_OFFLOAD:
466
+ return cpu_dev
467
+
468
+ model_size = dtype_size(dtype) * parameters
469
+
470
+ mem_dev = get_free_memory(torch_dev)
471
+ mem_cpu = get_free_memory(cpu_dev)
472
+ if mem_dev > mem_cpu and model_size < mem_dev:
473
+ return torch_dev
474
+ else:
475
+ return cpu_dev
476
+
477
+ def unet_dtype(device=None, model_params=0):
478
+ if args.unet_in_bf16:
479
+ return torch.bfloat16
480
+ if args.unet_in_fp16:
481
+ return torch.float16
482
+ if args.unet_in_fp8_e4m3fn:
483
+ return torch.float8_e4m3fn
484
+ if args.unet_in_fp8_e5m2:
485
+ return torch.float8_e5m2
486
+ if should_use_fp16(device=device, model_params=model_params):
487
+ return torch.float16
488
+ return torch.float32
489
+
490
+ # None means no manual cast
491
+ def unet_manual_cast(weight_dtype, inference_device):
492
+ if weight_dtype == torch.float32:
493
+ return None
494
+
495
+ fp16_supported = ldm_patched.modules.model_management.should_use_fp16(inference_device, prioritize_performance=False)
496
+ if fp16_supported and weight_dtype == torch.float16:
497
+ return None
498
+
499
+ if fp16_supported:
500
+ return torch.float16
501
+ else:
502
+ return torch.float32
503
+
504
+ def text_encoder_offload_device():
505
+ if args.always_gpu:
506
+ return get_torch_device()
507
+ else:
508
+ return torch.device("cpu")
509
+
510
+ def text_encoder_device():
511
+ if args.always_gpu:
512
+ return get_torch_device()
513
+ elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
514
+ if is_intel_xpu():
515
+ return torch.device("cpu")
516
+ if should_use_fp16(prioritize_performance=False):
517
+ return get_torch_device()
518
+ else:
519
+ return torch.device("cpu")
520
+ else:
521
+ return torch.device("cpu")
522
+
523
+ def text_encoder_dtype(device=None):
524
+ if args.clip_in_fp8_e4m3fn:
525
+ return torch.float8_e4m3fn
526
+ elif args.clip_in_fp8_e5m2:
527
+ return torch.float8_e5m2
528
+ elif args.clip_in_fp16:
529
+ return torch.float16
530
+ elif args.clip_in_fp32:
531
+ return torch.float32
532
+
533
+ if is_device_cpu(device):
534
+ return torch.float16
535
+
536
+ if should_use_fp16(device, prioritize_performance=False):
537
+ return torch.float16
538
+ else:
539
+ return torch.float32
540
+
541
+ def intermediate_device():
542
+ if args.always_gpu:
543
+ return get_torch_device()
544
+ else:
545
+ return torch.device("cpu")
546
+
547
+ def vae_device():
548
+ if args.vae_in_cpu:
549
+ return torch.device("cpu")
550
+ return get_torch_device()
551
+
552
+ def vae_offload_device():
553
+ if args.always_gpu:
554
+ return get_torch_device()
555
+ else:
556
+ return torch.device("cpu")
557
+
558
+ def vae_dtype():
559
+ global VAE_DTYPE
560
+ return VAE_DTYPE
561
+
562
+ def get_autocast_device(dev):
563
+ if hasattr(dev, 'type'):
564
+ return dev.type
565
+ return "cuda"
566
+
567
+ def supports_dtype(device, dtype): #TODO
568
+ if dtype == torch.float32:
569
+ return True
570
+ if is_device_cpu(device):
571
+ return False
572
+ if dtype == torch.float16:
573
+ return True
574
+ if dtype == torch.bfloat16:
575
+ return True
576
+ return False
577
+
578
+ def device_supports_non_blocking(device):
579
+ if is_device_mps(device):
580
+ return False #pytorch bug? mps doesn't support non blocking
581
+ return True
582
+
583
+ def cast_to_device(tensor, device, dtype, copy=False):
584
+ device_supports_cast = False
585
+ if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
586
+ device_supports_cast = True
587
+ elif tensor.dtype == torch.bfloat16:
588
+ if hasattr(device, 'type') and device.type.startswith("cuda"):
589
+ device_supports_cast = True
590
+ elif is_intel_xpu():
591
+ device_supports_cast = True
592
+
593
+ non_blocking = device_supports_non_blocking(device)
594
+
595
+ if device_supports_cast:
596
+ if copy:
597
+ if tensor.device == device:
598
+ return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
599
+ return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
600
+ else:
601
+ return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
602
+ else:
603
+ return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
604
+
605
+ def xformers_enabled():
606
+ global directml_enabled
607
+ global cpu_state
608
+ if cpu_state != CPUState.GPU:
609
+ return False
610
+ if is_intel_xpu():
611
+ return False
612
+ if directml_enabled:
613
+ return False
614
+ return XFORMERS_IS_AVAILABLE
615
+
616
+
617
+ def xformers_enabled_vae():
618
+ enabled = xformers_enabled()
619
+ if not enabled:
620
+ return False
621
+
622
+ return XFORMERS_ENABLED_VAE
623
+
624
+ def pytorch_attention_enabled():
625
+ global ENABLE_PYTORCH_ATTENTION
626
+ return ENABLE_PYTORCH_ATTENTION
627
+
628
+ def pytorch_attention_flash_attention():
629
+ global ENABLE_PYTORCH_ATTENTION
630
+ if ENABLE_PYTORCH_ATTENTION:
631
+ #TODO: more reliable way of checking for flash attention?
632
+ if is_nvidia(): #pytorch flash attention only works on Nvidia
633
+ return True
634
+ return False
635
+
636
+ def get_free_memory(dev=None, torch_free_too=False):
637
+ global directml_enabled
638
+ if dev is None:
639
+ dev = get_torch_device()
640
+
641
+ if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
642
+ mem_free_total = psutil.virtual_memory().available
643
+ mem_free_torch = mem_free_total
644
+ else:
645
+ if directml_enabled:
646
+ mem_free_total = 1024 * 1024 * 1024 #TODO
647
+ mem_free_torch = mem_free_total
648
+ elif is_intel_xpu():
649
+ stats = torch.xpu.memory_stats(dev)
650
+ mem_active = stats['active_bytes.all.current']
651
+ mem_allocated = stats['allocated_bytes.all.current']
652
+ mem_reserved = stats['reserved_bytes.all.current']
653
+ mem_free_torch = mem_reserved - mem_active
654
+ mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
655
+ else:
656
+ stats = torch.cuda.memory_stats(dev)
657
+ mem_active = stats['active_bytes.all.current']
658
+ mem_reserved = stats['reserved_bytes.all.current']
659
+ mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
660
+ mem_free_torch = mem_reserved - mem_active
661
+ mem_free_total = mem_free_cuda + mem_free_torch
662
+
663
+ if torch_free_too:
664
+ return (mem_free_total, mem_free_torch)
665
+ else:
666
+ return mem_free_total
667
+
668
+ def cpu_mode():
669
+ global cpu_state
670
+ return cpu_state == CPUState.CPU
671
+
672
+ def mps_mode():
673
+ global cpu_state
674
+ return cpu_state == CPUState.MPS
675
+
676
+ def is_device_cpu(device):
677
+ if hasattr(device, 'type'):
678
+ if (device.type == 'cpu'):
679
+ return True
680
+ return False
681
+
682
+ def is_device_mps(device):
683
+ if hasattr(device, 'type'):
684
+ if (device.type == 'mps'):
685
+ return True
686
+ return False
687
+
688
+ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
689
+ global directml_enabled
690
+
691
+ if device is not None:
692
+ if is_device_cpu(device):
693
+ return False
694
+
695
+ if FORCE_FP16:
696
+ return True
697
+
698
+ if device is not None: #TODO
699
+ if is_device_mps(device):
700
+ return False
701
+
702
+ if FORCE_FP32:
703
+ return False
704
+
705
+ if directml_enabled:
706
+ return False
707
+
708
+ if cpu_mode() or mps_mode():
709
+ return False #TODO ?
710
+
711
+ if is_intel_xpu():
712
+ return True
713
+
714
+ if torch.cuda.is_bf16_supported():
715
+ return True
716
+
717
+ props = torch.cuda.get_device_properties("cuda")
718
+ if props.major < 6:
719
+ return False
720
+
721
+ fp16_works = False
722
+ #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
723
+ #when the model doesn't actually fit on the card
724
+ #TODO: actually test if GP106 and others have the same type of behavior
725
+ nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
726
+ for x in nvidia_10_series:
727
+ if x in props.name.lower():
728
+ fp16_works = True
729
+
730
+ if fp16_works:
731
+ free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
732
+ if (not prioritize_performance) or model_params * 4 > free_model_memory:
733
+ return True
734
+
735
+ if props.major < 7:
736
+ return False
737
+
738
+ #FP16 is just broken on these cards
739
+ nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
740
+ for x in nvidia_16_series:
741
+ if x in props.name:
742
+ return False
743
+
744
+ return True
745
+
746
+ def soft_empty_cache(force=False):
747
+ global cpu_state
748
+ if cpu_state == CPUState.MPS:
749
+ torch.mps.empty_cache()
750
+ elif is_intel_xpu():
751
+ torch.xpu.empty_cache()
752
+ elif torch.cuda.is_available():
753
+ if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
754
+ torch.cuda.empty_cache()
755
+ torch.cuda.ipc_collect()
756
+
757
+ def unload_all_models():
758
+ free_memory(1e30, get_torch_device())
759
+
760
+
761
+ def resolve_lowvram_weight(weight, model, key): #TODO: remove
762
+ return weight
763
+
764
+ #TODO: might be cleaner to put this somewhere else
765
+ import threading
766
+
767
+ class InterruptProcessingException(Exception):
768
+ pass
769
+
770
+ interrupt_processing_mutex = threading.RLock()
771
+
772
+ interrupt_processing = False
773
+ def interrupt_current_processing(value=True):
774
+ global interrupt_processing
775
+ global interrupt_processing_mutex
776
+ with interrupt_processing_mutex:
777
+ interrupt_processing = value
778
+
779
+ def processing_interrupted():
780
+ global interrupt_processing
781
+ global interrupt_processing_mutex
782
+ with interrupt_processing_mutex:
783
+ return interrupt_processing
784
+
785
+ def throw_exception_if_processing_interrupted():
786
+ global interrupt_processing
787
+ global interrupt_processing_mutex
788
+ with interrupt_processing_mutex:
789
+ if interrupt_processing:
790
+ interrupt_processing = False
791
+ raise InterruptProcessingException()
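The interrupt helpers at the end of the file are meant to be polled from a worker loop while another thread sets the flag; a minimal sketch of that pattern, assuming the module imports as below (the work loop itself is hypothetical):

    import ldm_patched.modules.model_management as mm

    def run_steps(num_steps):
        # Poll the shared flag between steps and stop cleanly when an
        # interrupt has been requested from another thread.
        try:
            for step in range(num_steps):
                mm.throw_exception_if_processing_interrupted()
                ...  # one unit of work (e.g. a sampling step) goes here
        except mm.InterruptProcessingException:
            print("processing interrupted, stopping early")

    # From a UI or control thread, request the interruption:
    # mm.interrupt_current_processing(True)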