File size: 1,895 Bytes
45099b6 0f782f0 45099b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s', or a path ending in '.pt'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): accepted for interface compatibility; not used by this function
        device (str, torch.device, None): device to use for model parameters;
            None selects "cpu"

    Returns:
        YOLOv5 pytorch model

    Raises:
        Exception: if loading or building the model fails for any reason; the
            underlying error is chained as ``__cause__``.
    """
    from pathlib import Path

    from models.experimental import attempt_load
    from models.yolo import Model

    file = Path(__file__).absolute()
    # A name already ending in ".pt" is used as given (relative to the current
    # directory); a bare model name is resolved next to this file.
    save_dir = Path("") if str(name).endswith(".pt") else file.parent
    path = (save_dir / name).with_suffix(".pt")  # checkpoint path
    try:
        # Respect an explicitly requested device; fall back to CPU only when
        # the caller did not choose one.
        if device is None:
            device = "cpu"

        if pretrained and channels == 3 and classes == 80:
            model = attempt_load(path, map_location=device)  # download/load FP32 model
        else:
            # NOTE: this branch builds the architecture from its yaml only; no
            # checkpoint weights are loaded here, even when pretrained=True.
            models_dir = file.parent / "models"
            cfg = next(models_dir.rglob(f"{name}.yaml"), None)  # model.yaml path
            if cfg is None:
                raise FileNotFoundError(f"No model config '{name}.yaml' found under {models_dir}")
            model = Model(cfg, channels, classes)  # create model

        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        return model.to(device)
    except Exception as e:
        # Re-raise with a user-facing hint while keeping the original error chained.
        help_url = "https://github.com/ultralytics/yolov5/issues/36"
        s = f"Cache may be out of date, try `force_reload=True`. See {help_url} for help."
        raise Exception(s) from e
def custom(path="path/to/model.pt", autoshape=True, verbose=True, device=None):
    """Load a YOLOv5 custom or local model.

    Thin wrapper around ``_create``: ``path`` is passed as the model name and
    the remaining options are forwarded unchanged. ``pretrained``, ``channels``
    and ``classes`` are left at ``_create``'s defaults.

    Arguments:
        path (str): path to the model checkpoint, i.e. 'path/to/model.pt'
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): forwarded to ``_create``
        device (str, torch.device, None): forwarded to ``_create``

    Returns:
        whatever ``_create`` returns for these arguments
    """
    forwarded = {
        "autoshape": autoshape,
        "verbose": verbose,
        "device": device,
    }
    return _create(path, **forwarded)
|