glenn-jocher commited on
Commit
9b11f0c
β€’
1 Parent(s): 2d41e70

PyTorch Hub models default to CUDA:0 if available (#2472)

Browse files

* PyTorch Hub models default to CUDA:0 if available

* device as string bug fix

Files changed (4) hide show
  1. hubconf.py +3 -1
  2. utils/datasets.py +2 -2
  3. utils/general.py +1 -1
  4. utils/torch_utils.py +3 -3
hubconf.py CHANGED
@@ -12,6 +12,7 @@ import torch
12
  from models.yolo import Model
13
  from utils.general import set_logging
14
  from utils.google_utils import attempt_download
 
15
 
16
  dependencies = ['torch', 'yaml']
17
  set_logging()
@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
43
  model.names = ckpt['model'].names # set class names attribute
44
  if autoshape:
45
  model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
46
- return model
 
47
 
48
  except Exception as e:
49
  help_url = 'https://github.com/ultralytics/yolov5/issues/36'
 
12
  from models.yolo import Model
13
  from utils.general import set_logging
14
  from utils.google_utils import attempt_download
15
+ from utils.torch_utils import select_device
16
 
17
  dependencies = ['torch', 'yaml']
18
  set_logging()
 
44
  model.names = ckpt['model'].names # set class names attribute
45
  if autoshape:
46
  model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
47
+ device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
48
+ return model.to(device)
49
 
50
  except Exception as e:
51
  help_url = 'https://github.com/ultralytics/yolov5/issues/36'
utils/datasets.py CHANGED
@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
385
  # Display cache
386
  nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
387
  if exists:
388
- d = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
389
  tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
390
  assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
391
 
@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
485
  nc += 1
486
  print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
487
 
488
- pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
489
  f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
490
 
491
  if nf == 0:
 
385
  # Display cache
386
  nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
387
  if exists:
388
+ d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
389
  tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
390
  assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
391
 
 
485
  nc += 1
486
  print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
487
 
488
+ pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
489
  f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
490
 
491
  if nf == 0:
utils/general.py CHANGED
@@ -79,7 +79,7 @@ def check_git_status():
79
  f"Use 'git pull' to update or 'git clone {url}' to download latest."
80
  else:
81
  s = f'up to date with {url} βœ…'
82
- print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
83
  except Exception as e:
84
  print(e)
85
 
 
79
  f"Use 'git pull' to update or 'git clone {url}' to download latest."
80
  else:
81
  s = f'up to date with {url} βœ…'
82
+ print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
83
  except Exception as e:
84
  print(e)
85
 
utils/torch_utils.py CHANGED
@@ -1,8 +1,8 @@
1
  # PyTorch utils
2
-
3
  import logging
4
  import math
5
  import os
 
6
  import subprocess
7
  import time
8
  from contextlib import contextmanager
@@ -53,7 +53,7 @@ def git_describe():
53
 
54
  def select_device(device='', batch_size=None):
55
  # device = 'cpu' or '0' or '0,1,2,3'
56
- s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
57
  cpu = device.lower() == 'cpu'
58
  if cpu:
59
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
73
  else:
74
  s += 'CPU\n'
75
 
76
- logger.info(s) # skip a line
77
  return torch.device('cuda:0' if cuda else 'cpu')
78
 
79
 
 
1
  # PyTorch utils
 
2
  import logging
3
  import math
4
  import os
5
+ import platform
6
  import subprocess
7
  import time
8
  from contextlib import contextmanager
 
53
 
54
  def select_device(device='', batch_size=None):
55
  # device = 'cpu' or '0' or '0,1,2,3'
56
+ s = f'YOLOv5 πŸš€ {git_describe()} torch {torch.__version__} ' # string
57
  cpu = device.lower() == 'cpu'
58
  if cpu:
59
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
 
73
  else:
74
  s += 'CPU\n'
75
 
76
+ logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
77
  return torch.device('cuda:0' if cuda else 'cpu')
78
 
79