File size: 5,004 Bytes
d2b7e94 02e90e4 d2b7e94 02e90e4 d2b7e94 02e90e4 c5458aa 02e90e4 bed01bd 02e90e4 21f455c 02e90e4 374f426 bed01bd d5d0921 1df74c6 02e90e4 374f426 02e90e4 d5d0921 374f426 02e90e4 374f426 02e90e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import logging
import sys
from functools import lru_cache
import torch
from modules import config
logger = logging.getLogger(__name__)
if sys.platform == "darwin":
from modules.devices import mac_devices
def has_mps() -> bool:
    """
    Report whether the MPS backend should be considered usable.

    On any platform other than macOS (``sys.platform != "darwin"``) this is
    always ``False`` and ``mac_devices`` is never touched, since that module
    is only imported on macOS. On macOS the answer is whatever the
    ``mac_devices.has_mps`` attribute holds.
    """
    # `and` short-circuits, so `mac_devices` is only dereferenced on macOS.
    return sys.platform == "darwin" and mac_devices.has_mps
def get_cuda_device_id():
    """
    Return the CUDA device index that should be used.

    The configured ``config.runtime_env_vars.device_id`` wins whenever it is
    a non-negative decimal number (either a digit string such as ``"1"`` or an
    int). Otherwise the index of torch's current CUDA device is returned.

    :return: The CUDA device index as an int.
    """
    configured = config.runtime_env_vars.device_id

    # Compare against None explicitly and go through str(): an explicit "0"
    # is a valid choice and must not be treated as "unset", and an int value
    # has no .isdigit() of its own.
    if configured is not None and str(configured).isdigit():
        return int(configured)

    return torch.cuda.current_device()
def get_cuda_device_string():
    """
    Build the torch device string for the configured CUDA device.

    :return: ``"cuda:<device_id>"`` when ``config.runtime_env_vars.device_id``
        is set, otherwise the bare ``"cuda"`` (torch's default CUDA device).
    """
    device_id = config.runtime_env_vars.device_id
    return "cuda" if device_id is None else f"cuda:{device_id}"
def get_available_gpus() -> list[tuple[int, int]]:
    """
    Enumerate the CUDA devices visible to torch with a free-memory estimate.

    The estimate is the device's total memory minus the memory currently
    reserved by this process (``torch.cuda.memory_reserved``).

    :return: One ``(GPU index, estimated free memory in bytes)`` tuple per
        device, in index order. Empty when torch reports no CUDA devices.
    """
    return [
        (
            index,
            torch.cuda.get_device_properties(index).total_memory
            - torch.cuda.memory_reserved(index),
        )
        for index in range(torch.cuda.device_count())
    ]
def get_memory_available_gpus(min_memory=2048):
    """
    List the GPU indices whose estimated free memory exceeds ``min_memory``.

    :param min_memory: Threshold compared directly against the free-memory
        value reported by ``get_available_gpus()``; a GPU qualifies only when
        its value is strictly greater.
    :return: A list of qualifying GPU indices, in device order.
    """
    # NOTE(review): get_available_gpus() reports bytes, so the default of 2048
    # is a 2 KiB threshold -- confirm whether MiB was intended.
    usable = []
    for index, free_bytes in get_available_gpus():
        if free_bytes > min_memory:
            usable.append(index)
    return usable
def get_target_device_id_or_memory_available_gpu():
    """
    Choose the device string to run on, preferring the configured CUDA device.

    If the requested device is among the GPUs with enough free memory, its
    device string is returned. Otherwise the first GPU that does qualify is
    written back to ``config.runtime_env_vars.device_id`` (as a string) and
    used instead. When no GPU qualifies at all, ``"cpu"`` is returned.

    :return: A torch device string: ``"cuda"``, ``"cuda:<id>"`` or ``"cpu"``.
    """
    candidates = get_memory_available_gpus()
    requested = get_cuda_device_id()

    # Happy path: the requested device has enough memory.
    if requested in candidates:
        return get_cuda_device_string()

    # Nothing usable on any GPU: give up on CUDA entirely.
    if not candidates:
        logger.warning(
            f"Device {requested} is not available or does not have enough memory. Using CPU instead."
        )
        return "cpu"

    # Redirect to the first GPU that qualifies; the config is updated so that
    # get_cuda_device_string() (here and in later calls) points at it.
    logger.warning(
        f"Device {requested} is not available or does not have enough memory. will try to use {candidates}"
    )
    config.runtime_env_vars.device_id = str(candidates[0])
    return get_cuda_device_string()
def get_optimal_device_name():
    """
    Return the name of the best available torch device.

    Preference order: forced CPU (``"all"`` requested via
    ``config.runtime_env_vars.use_cpu``), then CUDA, then MPS, then CPU.

    :return: A torch device string such as ``"cpu"``, ``"mps"``, ``"cuda"``
        or ``"cuda:<id>"``.
    """
    use_cpu = config.runtime_env_vars.use_cpu

    # The rest of this module treats `use_cpu` as a collection of task names
    # ("all" in use_cpu), so accept that form as well as the bare string.
    # The isinstance guard avoids substring matches on other strings and a
    # TypeError when `use_cpu` is None.
    force_cpu = use_cpu == "all" or (
        isinstance(use_cpu, (list, tuple, set, frozenset)) and "all" in use_cpu
    )
    if force_cpu:
        return "cpu"

    if torch.cuda.is_available():
        return get_target_device_id_or_memory_available_gpu()

    if has_mps():
        return "mps"

    return "cpu"
def get_optimal_device():
    """
    Return the best available device as a ``torch.device`` object.

    This wraps the name chosen by ``get_optimal_device_name()``.
    """
    device_name = get_optimal_device_name()
    return torch.device(device_name)
def get_device_for(task):
    """
    Pick the device on which ``task`` should run.

    :param task: Task identifier looked up in
        ``config.runtime_env_vars.use_cpu``.
    :return: The module-level ``cpu`` device when ``task`` itself or ``"all"``
        is listed in ``use_cpu``; otherwise the optimal device.
    """
    cpu_tasks = config.runtime_env_vars.use_cpu

    # Neither this task nor the "all" wildcard is pinned to the CPU.
    if task not in cpu_tasks and "all" not in cpu_tasks:
        return get_optimal_device()

    return cpu
def torch_gc():
    """
    Release cached accelerator memory held by torch, on a best-effort basis.

    With CUDA available, the cache of the configured CUDA device is emptied
    and IPC memory is collected. When MPS is in use, the macOS-specific
    cleanup in ``mac_devices`` runs. Any exception is logged with its
    traceback and swallowed, so this function never raises.
    """
    try:
        if torch.cuda.is_available():
            cuda_device = get_cuda_device_string()
            with torch.cuda.device(cuda_device):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

        if has_mps():
            mac_devices.torch_mps_gc()
    except Exception:
        # Cleanup must never take the caller down; record the failure only.
        logger.exception("Error in torch_gc")
# Module-level device/dtype state. `reset_device()` reassigns `device` and
# every `dtype*` name below; until it runs they keep these initial values.

# Shared CPU device object, returned whenever a task is pinned to the CPU.
cpu: torch.device = torch.device("cpu")
# Active compute device; stays None until reset_device() has been called.
device: torch.device = None
# Default tensor dtype, followed by per-component dtypes. reset_device()
# sets all of them to the same precision (float16 or float32).
dtype: torch.dtype = torch.float32
dtype_dvae: torch.dtype = torch.float32
dtype_vocos: torch.dtype = torch.float32
dtype_gpt: torch.dtype = torch.float32
dtype_decoder: torch.dtype = torch.float32
def reset_device():
    """
    Recompute the module-level ``device`` and ``dtype*`` globals from config.

    Steps:

    1. Normalise ``config.runtime_env_vars.use_cpu`` to a list when unset.
    2. Resolve the target device: the shared ``cpu`` device when ``"all"`` is
       in ``use_cpu``, otherwise whatever ``get_optimal_device()`` selects.
    3. If the resolved device is a CPU and half precision is still enabled,
       warn and set ``config.runtime_env_vars.no_half = True``.
    4. Set every ``dtype*`` global to float16 (half) or float32 (full).

    Side effects: mutates ``config.runtime_env_vars`` (``use_cpu``,
    ``no_half``) and the module globals, and logs the chosen precision and
    device.
    """
    global device, dtype, dtype_dvae, dtype_vocos, dtype_gpt, dtype_decoder

    runtime = config.runtime_env_vars

    if runtime.use_cpu is None:
        runtime.use_cpu = []

    # Resolve the device BEFORE choosing the precision: get_optimal_device()
    # can itself fall back to the CPU (no CUDA/MPS, or no GPU with enough
    # memory), and that case needs full precision just like an explicit "all".
    if "all" in runtime.use_cpu:
        device = cpu
    else:
        device = get_optimal_device()

    if device.type == "cpu" and not runtime.no_half:
        logger.warning(
            "Cannot use half precision with CPU, using full precision instead"
        )
        runtime.no_half = True

    if not runtime.no_half:
        precision = torch.float16
        logger.info("Using half precision: torch.float16")
    else:
        precision = torch.float32
        logger.info("Using full precision: torch.float32")

    # All components share the same precision.
    dtype = dtype_dvae = dtype_vocos = dtype_gpt = dtype_decoder = precision

    logger.info(f"Using device: {device}")
@lru_cache
def first_time_calculation():
    """
    Warm up PyTorch by running one tiny Linear and one tiny Conv2d forward pass.

    Per the original author's note, the first calculation with PyTorch layers
    allocates about 700MB of memory and takes about 2.7 seconds, at least
    with NVidia. Because of ``lru_cache`` the body runs only on the first
    call. The tensors and layers are placed using the module-level ``device``
    and ``dtype`` as they are at that moment.
    """
    # Fully-connected warm-up.
    linear_input = torch.zeros((1, 1)).to(device, dtype)
    linear_layer = torch.nn.Linear(1, 1).to(device, dtype)
    linear_layer(linear_input)

    # Convolution warm-up: a single 3x3 input through a 3x3 kernel.
    conv_input = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv_layer = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv_layer(conv_input)
|