Commit
•
61047a2
1
Parent(s):
33202b7
Save PyTorch Hub models to `/root/hub/cache/dir` (#3904)
Browse files* Create hubconf.py
* Add save_dir variable
Co-authored-by: Glenn Jocher <[email protected]>
- hubconf.py +8 -7
hubconf.py
CHANGED
@@ -4,9 +4,12 @@ Usage:
|
|
4 |
import torch
|
5 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
6 |
"""
|
|
|
7 |
|
8 |
import torch
|
9 |
|
|
|
|
|
10 |
|
11 |
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
12 |
"""Creates a specified YOLOv5 model
|
@@ -23,28 +26,26 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
23 |
Returns:
|
24 |
YOLOv5 pytorch model
|
25 |
"""
|
26 |
-
from pathlib import Path
|
27 |
-
|
28 |
from models.yolo import Model, attempt_load
|
29 |
from utils.general import check_requirements, set_logging
|
30 |
from utils.google_utils import attempt_download
|
31 |
from utils.torch_utils import select_device
|
32 |
|
33 |
-
check_requirements(requirements=
|
34 |
-
exclude=('tensorboard', 'thop', 'opencv-python'))
|
35 |
set_logging(verbose=verbose)
|
36 |
|
37 |
-
|
|
|
38 |
try:
|
39 |
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
|
40 |
|
41 |
if pretrained and channels == 3 and classes == 80:
|
42 |
-
model = attempt_load(
|
43 |
else:
|
44 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
45 |
model = Model(cfg, channels, classes) # create model
|
46 |
if pretrained:
|
47 |
-
ckpt = torch.load(attempt_download(
|
48 |
msd = model.state_dict() # model state_dict
|
49 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
50 |
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
|
|
4 |
import torch
|
5 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
6 |
"""
|
7 |
+
from pathlib import Path
|
8 |
|
9 |
import torch
|
10 |
|
11 |
+
FILE = Path(__file__).absolute()
|
12 |
+
|
13 |
|
14 |
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
15 |
"""Creates a specified YOLOv5 model
|
|
|
26 |
Returns:
|
27 |
YOLOv5 pytorch model
|
28 |
"""
|
|
|
|
|
29 |
from models.yolo import Model, attempt_load
|
30 |
from utils.general import check_requirements, set_logging
|
31 |
from utils.google_utils import attempt_download
|
32 |
from utils.torch_utils import select_device
|
33 |
|
34 |
+
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=('tensorboard', 'thop', 'opencv-python'))
|
|
|
35 |
set_logging(verbose=verbose)
|
36 |
|
37 |
+
save_dir = Path('') if str(name).endswith('.pt') else FILE.parent
|
38 |
+
path = (save_dir / name).with_suffix('.pt') # checkpoint path
|
39 |
try:
|
40 |
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
|
41 |
|
42 |
if pretrained and channels == 3 and classes == 80:
|
43 |
+
model = attempt_load(path, map_location=device) # download/load FP32 model
|
44 |
else:
|
45 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
46 |
model = Model(cfg, channels, classes) # create model
|
47 |
if pretrained:
|
48 |
+
ckpt = torch.load(attempt_download(path), map_location=device) # load
|
49 |
msd = model.state_dict() # model state_dict
|
50 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
51 |
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|