|
import torch |
|
|
|
|
|
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s', or a path to a checkpoint ending in '.pt'
        pretrained (bool): load pretrained weights into the model. Weights are only loaded
            when channels == 3 and classes == 80; for any other combination a fresh model is
            built from the first matching '<name>.yaml' under this file's 'models' directory,
            without pretrained weights.
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen (currently not read by this function)
        device (str, torch.device, None): device to use for model parameters; 'cpu' when None

    Returns:
        YOLOv5 pytorch model, moved to ``device``

    Raises:
        Exception: if loading or building the model fails. The underlying error is chained
            as ``__cause__`` and the message points at the force_reload help issue.
    """
    from pathlib import Path

    from models.experimental import attempt_load
    from models.yolo import Model

    file = Path(__file__).absolute()

    # A name ending in '.pt' is used as a path as given (absolute, or relative to the CWD);
    # a bare model name such as 'yolov5s' is resolved next to this file.
    save_dir = Path("") if str(name).endswith(".pt") else file.parent
    path = (save_dir / name).with_suffix(".pt")

    # Honour the caller-supplied device. It was previously overwritten with "cpu"
    # unconditionally, so keep "cpu" as the fallback when nothing is given.
    if device is None:
        device = "cpu"

    try:
        if pretrained and channels == 3 and classes == 80:
            model = attempt_load(path, map_location=device)
        else:
            # Build an untrained model from the first '<name>.yaml' found under models/.
            models_dir = file.parent / "models"
            cfg = next(models_dir.rglob(f"{name}.yaml"), None)
            if cfg is None:
                raise FileNotFoundError(f"No config '{name}.yaml' found under {models_dir}")
            model = Model(cfg, channels, classes)

        if autoshape:
            model = model.autoshape()

        return model.to(device)

    except Exception as e:
        help_url = "https://github.com/ultralytics/yolov5/issues/36"
        s = f"Cache may be out of date, try `force_reload=True`. See {help_url} for help."
        raise Exception(s) from e
|
|
|
|
|
def custom(path="path/to/model.pt", autoshape=True, verbose=True, device=None):
    """Load a user-supplied YOLOv5 checkpoint.

    Thin wrapper around ``_create`` that passes ``path`` as the model name and
    leaves ``pretrained``, ``channels`` and ``classes`` at ``_create``'s defaults.

    Arguments:
        path (str): path to the checkpoint file, e.g. 'path/to/model.pt'
        autoshape (bool): forwarded to ``_create``
        verbose (bool): forwarded to ``_create``
        device (str, torch.device, None): forwarded to ``_create``

    Returns:
        whatever ``_create`` returns for this checkpoint
    """
    forwarded = dict(autoshape=autoshape, verbose=verbose, device=device)
    return _create(path, **forwarded)
|
|