Spaces:
Runtime error
Runtime error
import torch | |
import logging | |
from packaging import version | |
import torch.backends | |
import torch.backends.mps | |
logger = logging.getLogger(__name__) | |
def check_for_mps() -> bool:
    """Return whether the PyTorch MPS (Apple Metal) backend appears usable.

    For torch versions <= 2.0.1 the legacy ``torch.has_mps`` attribute is
    consulted and, if truthy, a 1-element tensor is moved to the ``"mps"``
    device as a smoke test. For newer versions the answer comes from
    ``torch.backends.mps.is_available()`` and ``torch.backends.mps.is_built()``.

    Returns:
        True if MPS appears usable, False otherwise. Any ordinary exception
        raised while probing is treated as "not available".
    """
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        # ``has_mps`` may be absent on older builds, hence the getattr default.
        if not getattr(torch, "has_mps", False):
            return False
        try:
            # Confirm the device can actually allocate a tensor, not just
            # that the flag is set.
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            return False

    try:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate; log what actually failed here.
        logger.warning("MPS availability check failed", exc_info=True)
        return False
# Evaluated a single time, at import, and kept as a module-level flag.
has_mps: bool = check_for_mps()
def torch_mps_gc() -> None:
    """Best-effort request for PyTorch to empty its MPS cache.

    The import of ``torch.mps.empty_cache`` happens lazily inside the guarded
    region, so a failure either to import it or to call it is caught, logged
    as a warning together with its traceback, and not re-raised.
    """
    try:
        from torch.mps import empty_cache as release_mps_cache

        release_mps_cache()
    except Exception:
        logger.warning("MPS garbage collection failed", exc_info=True)
if __name__ == "__main__":
    # Quick diagnostic: report the installed torch version and whether MPS
    # was detected at import time.
    print(torch.__version__)
    print(has_mps)
    # Only attempt the cache flush when MPS was detected. On hosts without
    # MPS, empty_cache() typically raises, which would log a spurious
    # "garbage collection failed" warning with a full traceback.
    if has_mps:
        torch_mps_gc()