johnohagan glenn-jocher commited on
Commit
61047a2
1 Parent(s): 33202b7

Save PyTorch Hub models to `/root/hub/cache/dir` (#3904)

Browse files

* Create hubconf.py

* Add save_dir variable

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. hubconf.py +8 -7
hubconf.py CHANGED
@@ -4,9 +4,12 @@ Usage:
4
  import torch
5
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
6
  """
 
7
 
8
  import torch
9
 
 
 
10
 
11
  def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
12
  """Creates a specified YOLOv5 model
@@ -23,28 +26,26 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
23
  Returns:
24
  YOLOv5 pytorch model
25
  """
26
- from pathlib import Path
27
-
28
  from models.yolo import Model, attempt_load
29
  from utils.general import check_requirements, set_logging
30
  from utils.google_utils import attempt_download
31
  from utils.torch_utils import select_device
32
 
33
- check_requirements(requirements=Path(__file__).parent / 'requirements.txt',
34
- exclude=('tensorboard', 'thop', 'opencv-python'))
35
  set_logging(verbose=verbose)
36
 
37
- fname = Path(name).with_suffix('.pt') # checkpoint filename
 
38
  try:
39
  device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
40
 
41
  if pretrained and channels == 3 and classes == 80:
42
- model = attempt_load(fname, map_location=device) # download/load FP32 model
43
  else:
44
  cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
45
  model = Model(cfg, channels, classes) # create model
46
  if pretrained:
47
- ckpt = torch.load(attempt_download(fname), map_location=device) # load
48
  msd = model.state_dict() # model state_dict
49
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
50
  csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
 
4
  import torch
5
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
6
  """
7
+ from pathlib import Path
8
 
9
  import torch
10
 
11
+ FILE = Path(__file__).absolute()
12
+
13
 
14
  def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
15
  """Creates a specified YOLOv5 model
 
26
  Returns:
27
  YOLOv5 pytorch model
28
  """
 
 
29
  from models.yolo import Model, attempt_load
30
  from utils.general import check_requirements, set_logging
31
  from utils.google_utils import attempt_download
32
  from utils.torch_utils import select_device
33
 
34
+ check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=('tensorboard', 'thop', 'opencv-python'))
 
35
  set_logging(verbose=verbose)
36
 
37
+ save_dir = Path('') if str(name).endswith('.pt') else FILE.parent
38
+ path = (save_dir / name).with_suffix('.pt') # checkpoint path
39
  try:
40
  device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
41
 
42
  if pretrained and channels == 3 and classes == 80:
43
+ model = attempt_load(path, map_location=device) # download/load FP32 model
44
  else:
45
  cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
46
  model = Model(cfg, channels, classes) # create model
47
  if pretrained:
48
+ ckpt = torch.load(attempt_download(path), map_location=device) # load
49
  msd = model.state_dict() # model state_dict
50
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
51
  csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter