glenn-jocher
commited on
Commit
•
d08575e
1
Parent(s):
9b91db6
PyTorch Hub load directly when possible (#2986)
Browse files- hubconf.py +23 -19
hubconf.py
CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
|
|
9 |
|
10 |
import torch
|
11 |
|
12 |
-
from models.yolo import Model
|
13 |
from utils.general import check_requirements, set_logging
|
14 |
from utils.google_utils import attempt_download
|
15 |
from utils.torch_utils import select_device
|
@@ -26,33 +26,37 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
|
26 |
pretrained (bool): load pretrained weights into the model
|
27 |
channels (int): number of input channels
|
28 |
classes (int): number of model classes
|
|
|
|
|
29 |
|
30 |
Returns:
|
31 |
-
pytorch model
|
32 |
"""
|
|
|
|
|
33 |
try:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
51 |
return model.to(device)
|
52 |
|
53 |
except Exception as e:
|
54 |
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
55 |
-
s = 'Cache
|
56 |
raise Exception(s) from e
|
57 |
|
58 |
|
|
|
9 |
|
10 |
import torch
|
11 |
|
12 |
+
from models.yolo import Model, attempt_load
|
13 |
from utils.general import check_requirements, set_logging
|
14 |
from utils.google_utils import attempt_download
|
15 |
from utils.torch_utils import select_device
|
|
|
26 |
pretrained (bool): load pretrained weights into the model
|
27 |
channels (int): number of input channels
|
28 |
classes (int): number of model classes
|
29 |
+
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
|
30 |
+
verbose (bool): print all information to screen
|
31 |
|
32 |
Returns:
|
33 |
+
YOLOv5 pytorch model
|
34 |
"""
|
35 |
+
set_logging(verbose=verbose)
|
36 |
+
fname = f'{name}.pt' # checkpoint filename
|
37 |
try:
|
38 |
+
if pretrained and channels == 3 and classes == 80:
|
39 |
+
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
|
40 |
+
else:
|
41 |
+
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
42 |
+
model = Model(cfg, channels, classes) # create model
|
43 |
+
if pretrained:
|
44 |
+
attempt_download(fname) # download if not found locally
|
45 |
+
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
46 |
+
msd = model.state_dict() # model state_dict
|
47 |
+
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
48 |
+
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
49 |
+
model.load_state_dict(csd, strict=False) # load
|
50 |
+
if len(ckpt['model'].names) == classes:
|
51 |
+
model.names = ckpt['model'].names # set class names attribute
|
52 |
+
if autoshape:
|
53 |
+
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
54 |
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
55 |
return model.to(device)
|
56 |
|
57 |
except Exception as e:
|
58 |
help_url = 'https://github.com/ultralytics/yolov5/issues/36'
|
59 |
+
s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
|
60 |
raise Exception(s) from e
|
61 |
|
62 |
|