r3gm commited on
Commit
3e3db38
1 Parent(s): 3ace12e

Update src/mdx.py

Browse files
Files changed (1) hide show
  1. src/mdx.py +2 -1
src/mdx.py CHANGED
@@ -65,7 +65,8 @@ class MDX:
65
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
 
67
  # Set the device and the provider (CPU or CUDA)
68
- self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
 
69
  self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
70
 
71
  self.model = params
 
65
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
 
67
  # Set the device and the provider (CPU or CUDA)
68
+ #self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
69
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
70
  self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
71
 
72
  self.model = params