glenn-jocher commited on
Commit
f7a923b
1 Parent(s): 87ca35b

Simplified PyTorch hub for custom models (#1677)

Browse files
Files changed (1) hide show
  1. hubconf.py +6 -7
hubconf.py CHANGED
@@ -106,19 +106,18 @@ def yolov5x(pretrained=False, channels=3, classes=80):
106
  return create('yolov5x', pretrained, channels, classes)
107
 
108
 
109
- def custom(model='path/to/model.pt'):
110
  """YOLOv5-custom model from https://github.com/ultralytics/yolov5
111
 
112
- Arguments (3 format options):
113
- model (str): 'path/to/model.pt'
114
- model (dict): torch.load('path/to/model.pt')
115
- model (nn.Module): 'torch.load('path/to/model.pt')['model']
116
 
117
  Returns:
118
  pytorch model
119
  """
120
- if isinstance(model, str):
121
- model = torch.load(model) # load checkpoint
122
  if isinstance(model, dict):
123
  model = model['model'] # load model
124
 
 
106
  return create('yolov5x', pretrained, channels, classes)
107
 
108
 
109
+ def custom(path_or_model='path/to/model.pt'):
110
  """YOLOv5-custom model from https://github.com/ultralytics/yolov5
111
 
112
+ Arguments (3 options):
113
+ path_or_model (str): 'path/to/model.pt'
114
+ path_or_model (dict): torch.load('path/to/model.pt')
115
+ path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
116
 
117
  Returns:
118
  pytorch model
119
  """
120
+ model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
 
121
  if isinstance(model, dict):
122
  model = model['model'] # load model
123