Cédric Colas
initial commit
e775f6d
raw
history blame
1.4 kB
# For all things related to devices
#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####
import torch
# The CPU device always exists.
TORCH_CPU_DEVICE = torch.device("cpu")

# Resolve the CUDA device once, at import time. None records that
# torch reported zero CUDA devices.
TORCH_CUDA_DEVICE = torch.device("cuda") if torch.cuda.device_count() > 0 else None
if TORCH_CUDA_DEVICE is None:
    print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
    print("")

# Preference flag changed through use_cuda(); CUDA is preferred by default.
USE_CUDA = True
# use_cuda
def use_cuda(cuda_bool):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Records the caller's device preference in the module-level USE_CUDA flag.
    Pass True to prefer CUDA, or False to select the CPU. The value is stored
    as given, without conversion.
    ----------
    """

    global USE_CUDA
    USE_CUDA = cuda_bool
# get_device
def get_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the device that should be used by default: the CUDA device when
    the USE_CUDA flag is truthy and a CUDA device was found, and the CPU
    device in every other case.
    ----------
    """

    # CUDA is chosen only when it is both requested and present.
    if USE_CUDA and TORCH_CUDA_DEVICE is not None:
        return TORCH_CUDA_DEVICE

    return TORCH_CPU_DEVICE
# cuda_device
def cuda_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the module's CUDA device regardless of the use_cuda setting.
    The result is None when no CUDA device was detected, so callers must
    handle that case.
    ----------
    """

    return TORCH_CUDA_DEVICE
# cpu_device
def cpu_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Returns the module's CPU device regardless of the use_cuda setting.
    ----------
    """

    return TORCH_CPU_DEVICE