glenn-jocher
commited on
Commit
β’
9b11f0c
1
Parent(s):
2d41e70
PyTorch Hub models default to CUDA:0 if available (#2472)
Browse files* PyTorch Hub models default to CUDA:0 if available
* device as string bug fix
- hubconf.py +3 -1
- utils/datasets.py +2 -2
- utils/general.py +1 -1
- utils/torch_utils.py +3 -3
hubconf.py
CHANGED
@@ -12,6 +12,7 @@ import torch
|
|
12 |
from models.yolo import Model
|
13 |
from utils.general import set_logging
|
14 |
from utils.google_utils import attempt_download
|
|
|
15 |
|
16 |
dependencies = ['torch', 'yaml']
|
17 |
set_logging()
|
@@ -43,7 +44,8 @@ def create(name, pretrained, channels, classes, autoshape):
|
|
43 |
model.names = ckpt['model'].names # set class names attribute
|
44 |
if autoshape:
|
45 |
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
46 |
-
|
|
|
47 |
|
48 |
except Exception as e:
|
49 |
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
|
|
12 |
from models.yolo import Model
|
13 |
from utils.general import set_logging
|
14 |
from utils.google_utils import attempt_download
|
15 |
+
from utils.torch_utils import select_device
|
16 |
|
17 |
dependencies = ['torch', 'yaml']
|
18 |
set_logging()
|
|
|
44 |
model.names = ckpt['model'].names # set class names attribute
|
45 |
if autoshape:
|
46 |
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
47 |
+
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
48 |
+
return model.to(device)
|
49 |
|
50 |
except Exception as e:
|
51 |
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
utils/datasets.py
CHANGED
@@ -385,7 +385,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
385 |
# Display cache
|
386 |
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
|
387 |
if exists:
|
388 |
-
d = f"Scanning '{cache_path}'
|
389 |
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
|
390 |
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
|
391 |
|
@@ -485,7 +485,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|
485 |
nc += 1
|
486 |
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
|
487 |
|
488 |
-
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}'
|
489 |
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
|
490 |
|
491 |
if nf == 0:
|
|
|
385 |
# Display cache
|
386 |
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
|
387 |
if exists:
|
388 |
+
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
|
389 |
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
|
390 |
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
|
391 |
|
|
|
485 |
nc += 1
|
486 |
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
|
487 |
|
488 |
+
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
|
489 |
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
|
490 |
|
491 |
if nf == 0:
|
utils/general.py
CHANGED
@@ -79,7 +79,7 @@ def check_git_status():
|
|
79 |
f"Use 'git pull' to update or 'git clone {url}' to download latest."
|
80 |
else:
|
81 |
s = f'up to date with {url} β
'
|
82 |
-
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
|
83 |
except Exception as e:
|
84 |
print(e)
|
85 |
|
|
|
79 |
f"Use 'git pull' to update or 'git clone {url}' to download latest."
|
80 |
else:
|
81 |
s = f'up to date with {url} β
'
|
82 |
+
print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
|
83 |
except Exception as e:
|
84 |
print(e)
|
85 |
|
utils/torch_utils.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
# PyTorch utils
|
2 |
-
|
3 |
import logging
|
4 |
import math
|
5 |
import os
|
|
|
6 |
import subprocess
|
7 |
import time
|
8 |
from contextlib import contextmanager
|
@@ -53,7 +53,7 @@ def git_describe():
|
|
53 |
|
54 |
def select_device(device='', batch_size=None):
|
55 |
# device = 'cpu' or '0' or '0,1,2,3'
|
56 |
-
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
|
57 |
cpu = device.lower() == 'cpu'
|
58 |
if cpu:
|
59 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
@@ -73,7 +73,7 @@ def select_device(device='', batch_size=None):
|
|
73 |
else:
|
74 |
s += 'CPU\n'
|
75 |
|
76 |
-
logger.info(s) #
|
77 |
return torch.device('cuda:0' if cuda else 'cpu')
|
78 |
|
79 |
|
|
|
1 |
# PyTorch utils
|
|
|
2 |
import logging
|
3 |
import math
|
4 |
import os
|
5 |
+
import platform
|
6 |
import subprocess
|
7 |
import time
|
8 |
from contextlib import contextmanager
|
|
|
53 |
|
54 |
def select_device(device='', batch_size=None):
|
55 |
# device = 'cpu' or '0' or '0,1,2,3'
|
56 |
+
s = f'YOLOv5 π {git_describe()} torch {torch.__version__} ' # string
|
57 |
cpu = device.lower() == 'cpu'
|
58 |
if cpu:
|
59 |
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
|
|
73 |
else:
|
74 |
s += 'CPU\n'
|
75 |
|
76 |
+
logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
|
77 |
return torch.device('cuda:0' if cuda else 'cpu')
|
78 |
|
79 |
|