chattts / modules /devices /devices.py
zhzluke96
update
bed01bd
raw
history blame
5 kB
import logging
import sys
from functools import lru_cache
from typing import Optional

import torch

from modules import config
# Module-level logger for device-selection messages.
logger = logging.getLogger(__name__)

# mac_devices is imported only on macOS. The rest of this module reaches it
# only through has_mps() (which returns False off darwin) and torch_gc()
# (which calls into it only when has_mps() is true).
if sys.platform == "darwin":
    from modules.devices import mac_devices
def has_mps() -> bool:
    """Report whether the MPS backend is usable.

    Always ``False`` when not running on macOS. On macOS the answer is
    whatever ``mac_devices.has_mps`` holds.
    """
    on_macos = sys.platform == "darwin"
    return mac_devices.has_mps if on_macos else False
def get_cuda_device_id():
    """Return the CUDA device index to use.

    If ``config.runtime_env_vars.device_id`` is a string of digits
    (including ``"0"``), its integer value is returned. Otherwise the
    result is ``torch.cuda.current_device()``.
    """
    device_id = config.runtime_env_vars.device_id
    # Test for a digit string explicitly rather than relying on truthiness,
    # so an explicitly configured device "0" is honoured instead of being
    # treated as "unset".
    if device_id is not None and device_id.isdigit():
        return int(device_id)
    return torch.cuda.current_device()
def get_cuda_device_string():
    """Return a torch device string for the configured CUDA device.

    Returns ``"cuda:<N>"`` when ``config.runtime_env_vars.device_id`` is a
    digit string, and plain ``"cuda"`` otherwise.
    """
    device_id = config.runtime_env_vars.device_id
    # Apply the same validity test as get_cuda_device_id(), so an id that
    # get_cuda_device_id() ignores never produces a malformed string such
    # as "cuda:abc".
    if device_id is not None and device_id.isdigit():
        return f"cuda:{device_id}"
    return "cuda"
def get_available_gpus() -> list[tuple[int, int]]:
    """
    List every visible CUDA device with an estimate of its free memory.

    "Free" is computed as the device's ``total_memory`` minus
    ``torch.cuda.memory_reserved`` for that device.

    :return: ``(gpu_index, free_bytes)`` tuples, one per device, in index order.
    """
    # NOTE(review): memory_reserved reports PyTorch's caching allocator, so
    # memory held by other processes is presumably not subtracted; verify
    # whether that matters for the callers' threshold.
    return [
        (
            index,
            torch.cuda.get_device_properties(index).total_memory
            - torch.cuda.memory_reserved(index),
        )
        for index in range(torch.cuda.device_count())
    ]
def get_memory_available_gpus(min_memory=2048):
    """Return indices of GPUs whose estimated free memory exceeds ``min_memory``.

    :param min_memory: threshold compared directly against the byte counts
        from ``get_available_gpus()``; the comparison is strict (``>``).
    :return: GPU indices, in the order ``get_available_gpus()`` reports them.
    """
    # NOTE(review): the default 2048 is compared against *bytes* (2 KiB),
    # which almost any GPU exceeds; 2048 MiB may have been intended — confirm.
    selected = []
    for index, free_bytes in get_available_gpus():
        if free_bytes > min_memory:
            selected.append(index)
    return selected
def get_target_device_id_or_memory_available_gpu():
    """Pick a CUDA device string, falling back to another GPU or to the CPU.

    - If the configured device (``get_cuda_device_id()``) is among the GPUs
      returned by ``get_memory_available_gpus()``, return its device string.
    - Otherwise, if any such GPU exists, store the first one in
      ``config.runtime_env_vars.device_id`` and return its device string.
    - If there is none, return ``"cpu"``.
    """
    candidates = get_memory_available_gpus()
    wanted = get_cuda_device_id()

    if wanted in candidates:
        return get_cuda_device_string()

    if not candidates:
        logger.warning(
            f"Device {wanted} is not available or does not have enough memory. Using CPU instead."
        )
        return "cpu"

    logger.warning(
        f"Device {wanted} is not available or does not have enough memory. will try to use {candidates}"
    )
    # Redirect the global config so get_cuda_device_string() (here and in
    # later callers) targets the first GPU with enough memory.
    config.runtime_env_vars.device_id = str(candidates[0])
    return get_cuda_device_string()
def get_optimal_device_name():
    """Return the name of the preferred torch device.

    Order of preference:
    1. ``"cpu"`` when ``config.runtime_env_vars.use_cpu`` is ``"all"`` or a
       collection containing ``"all"``;
    2. a CUDA device via ``get_target_device_id_or_memory_available_gpu()``;
    3. ``"mps"``;
    4. ``"cpu"``.
    """
    use_cpu = config.runtime_env_vars.use_cpu
    # reset_device() normalises use_cpu to a list and the rest of this module
    # tests it with `"all" in use_cpu`; accept that shape as well as the bare
    # string "all", so forcing the CPU works whichever form is configured.
    forced_cpu = use_cpu == "all" or (
        isinstance(use_cpu, (list, tuple, set, frozenset)) and "all" in use_cpu
    )
    if forced_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return get_target_device_id_or_memory_available_gpu()
    if has_mps():
        return "mps"
    return "cpu"
def get_optimal_device():
    """Return ``get_optimal_device_name()`` wrapped in a ``torch.device``."""
    name = get_optimal_device_name()
    return torch.device(name)
def get_device_for(task):
    """Return the device a given task should run on.

    :param task: task name looked up in ``config.runtime_env_vars.use_cpu``.
    :return: the module-level ``cpu`` device when ``task`` or ``"all"`` is in
        ``use_cpu``; otherwise ``get_optimal_device()``.
    """
    use_cpu = config.runtime_env_vars.use_cpu
    # reset_device() normalises a missing use_cpu to []; treat None the same
    # way here, so calling this before reset_device() doesn't raise TypeError.
    if use_cpu is None:
        use_cpu = []
    if task in use_cpu or "all" in use_cpu:
        return cpu
    return get_optimal_device()
def torch_gc():
    """Release cached accelerator memory, best effort.

    On CUDA, runs ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` inside the device context named by
    ``get_cuda_device_string()``. When ``has_mps()`` is true, calls
    ``mac_devices.torch_mps_gc()``. Any exception is logged with its
    traceback and swallowed.
    """
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(get_cuda_device_string()):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        if has_mps():
            mac_devices.torch_mps_gc()
    except Exception:
        # Cleanup is deliberately best-effort: log at ERROR with the
        # traceback (logger.exception) instead of propagating.
        logger.exception("Error in torch_gc")
# The CPU device; get_device_for() returns it for tasks forced onto the CPU.
cpu: torch.device = torch.device("cpu")
# Active device, assigned by reset_device(); None until reset_device() runs.
device: Optional[torch.device] = None
# Working dtypes. They start as float32; reset_device() switches them to
# float16 unless config.runtime_env_vars.no_half is set.
dtype: torch.dtype = torch.float32
dtype_dvae: torch.dtype = torch.float32
dtype_vocos: torch.dtype = torch.float32
dtype_gpt: torch.dtype = torch.float32
dtype_decoder: torch.dtype = torch.float32
def reset_device():
    """(Re)select the global device and dtypes from ``config.runtime_env_vars``.

    Side effects:
    - sets ``use_cpu`` to ``[]`` when it is None;
    - sets ``no_half = True`` (with a warning) when the selected device is
      the CPU and half precision was requested;
    - assigns the module globals ``device``, ``dtype``, ``dtype_dvae``,
      ``dtype_vocos``, ``dtype_gpt`` and ``dtype_decoder``.
    """
    global device
    global dtype
    global dtype_dvae
    global dtype_vocos
    global dtype_gpt
    global dtype_decoder

    env = config.runtime_env_vars
    if env.use_cpu is None:
        env.use_cpu = []

    # Resolve the device before choosing precision: get_optimal_device() can
    # fall back to the CPU (no CUDA/MPS, or no GPU with enough memory) even
    # when "all" is not in use_cpu, and float16 must not be chosen then.
    selected = cpu if "all" in env.use_cpu else get_optimal_device()

    if selected.type == "cpu" and not env.no_half:
        logger.warning(
            "Cannot use half precision with CPU, using full precision instead"
        )
        env.no_half = True

    if not env.no_half:
        precision = torch.float16
        logger.info("Using half precision: torch.float16")
    else:
        precision = torch.float32
        logger.info("Using full precision: torch.float32")

    # All model components currently share a single precision.
    dtype = dtype_dvae = dtype_vocos = dtype_gpt = dtype_decoder = precision

    device = selected
    logger.info(f"Using device: {device}")
@lru_cache
def first_time_calculation():
    """
    Run one tiny Linear and one tiny Conv2d forward pass on the global
    ``device``/``dtype`` so PyTorch's first-use setup cost is paid up front.

    The first layer computation reportedly allocates about 700MB of memory
    and takes about 2.7 seconds, at least with NVIDIA GPUs. Because of
    ``lru_cache`` on this zero-argument function, later calls return the
    cached result without recomputing.
    """
    # (input shape, layer factory) pairs, run in order. Each input is created
    # before its layer, and the layer is built lazily, matching a straight-line
    # sequence of zeros -> build layer -> forward.
    probes = (
        ((1, 1), lambda: torch.nn.Linear(1, 1)),
        ((1, 1, 3, 3), lambda: torch.nn.Conv2d(1, 1, (3, 3))),
    )
    for input_shape, build_layer in probes:
        sample = torch.zeros(input_shape).to(device, dtype)
        layer = build_layer().to(device, dtype)
        layer(sample)