File size: 5,533 Bytes
2bcc89d
a814720
 
 
2bcc89d
a814720
 
1c8464e
0dec8ff
a814720
 
 
2bcc89d
d5b6416
9b11f0c
a814720
2aaaec3
8ace1b1
7f16406
 
a814720
14b0abe
a814720
 
 
 
 
 
 
 
 
 
 
1c8464e
6e3c3b6
 
 
7f16406
 
 
 
6e3c3b6
 
7f16406
 
14b0abe
 
9b11f0c
 
84a9466
6e3c3b6
 
0fda95a
84a9466
a814720
 
569757e
a814720
 
 
 
 
 
 
 
 
 
14b0abe
a814720
 
569757e
a814720
 
 
 
 
 
 
 
 
 
14b0abe
a814720
 
569757e
a814720
 
 
 
 
 
 
 
 
 
14b0abe
a814720
 
569757e
a814720
 
 
 
 
 
 
 
 
 
14b0abe
7f16406
 
14b0abe
87ca35b
 
f7a923b
 
 
 
87ca35b
 
 
 
f7a923b
87ca35b
fab5085
87ca35b
 
 
 
14b0abe
87ca35b
 
7f16406
14b0abe
e92245a
c8c5ef3
 
db28ce6
311de00
 
db28ce6
 
 
 
 
 
f542926
9a3da79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""File for accessing YOLOv5 models via PyTorch Hub https://pytorch.org/hub/ultralytics_yolov5/

Usage:
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""

from pathlib import Path

import torch

from models.yolo import Model
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device

# PyTorch Hub convention: packages that must be importable before any entry point here is loaded
dependencies = ['torch', 'yaml']
# Check the repo's requirements.txt at import time, excluding pycocotools and thop from the check
check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop'))
set_logging()  # presumably configures process-wide logging; see utils.general.set_logging


def create(name, pretrained, channels, classes, autoshape):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): wrap the model for file/URI/PIL/cv2/np inputs and NMS.
            Only applied when `pretrained` is True; untrained models are returned unwrapped.

    Returns:
        pytorch model, moved to CUDA device '0' if CUDA is available, else to the CPU

    Raises:
        Exception: wraps any failure during creation/loading, suggesting force_reload=True
    """
    config = Path(__file__).parent / 'models' / f'{name}.yaml'  # model.yaml path
    try:
        model = Model(config, channels, classes)
        if pretrained:
            fname = f'{name}.pt'  # checkpoint filename
            attempt_download(fname)  # download if not found locally
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
            state_dict = ckpt['model'].float().state_dict()  # to FP32
            model_state = model.state_dict()  # compute once instead of once per checkpoint key
            # Keep only tensors the new model has with an identical shape (presumably this drops the
            # output head when `classes` differs from the checkpoint); unknown keys are skipped rather
            # than raising KeyError, consistent with strict=False below.
            state_dict = {k: v for k, v in state_dict.items()
                          if k in model_state and model_state[k].shape == v.shape}  # filter
            model.load_state_dict(state_dict, strict=False)  # load
            if len(ckpt['model'].names) == classes:
                model.names = ckpt['model'].names  # set class names attribute
            if autoshape:
                model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try force_reload=True. See {help_url} for help.'
        raise Exception(s) from e


def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-small model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): wrap the pretrained model for file/URI/PIL/cv2/np inputs and NMS, default=True

    Returns:
        pytorch model
    """
    return create('yolov5s', pretrained, channels, classes, autoshape)


def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-medium model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): wrap the pretrained model for file/URI/PIL/cv2/np inputs and NMS, default=True

    Returns:
        pytorch model
    """
    return create('yolov5m', pretrained, channels, classes, autoshape)


def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-large model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): wrap the pretrained model for file/URI/PIL/cv2/np inputs and NMS, default=True

    Returns:
        pytorch model
    """
    return create('yolov5l', pretrained, channels, classes, autoshape)


def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): wrap the pretrained model for file/URI/PIL/cv2/np inputs and NMS, default=True

    Returns:
        pytorch model
    """
    return create('yolov5x', pretrained, channels, classes, autoshape)


def custom(path_or_model='path/to/model.pt', autoshape=True):
    """YOLOv5-custom model from https://github.com/ultralytics/yolov5

    Arguments (3 options):
        path_or_model (str or Path): 'path/to/model.pt'
        path_or_model (dict): torch.load('path/to/model.pt')
        path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
        autoshape (bool): wrap the model for file/URI/PIL/cv2/np inputs and NMS, default=True

    Returns:
        pytorch model, on the same device as the source model's first parameter
    """
    if isinstance(path_or_model, (str, Path)):
        # Remap to CPU only when CUDA is unavailable: GPU-saved checkpoints still land on the GPU
        # when one exists (as before), but no longer fail to deserialize on CPU-only machines.
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        model = torch.load(path_or_model, map_location=map_location)  # load checkpoint
    else:
        model = path_or_model
    if isinstance(model, dict):
        model = model['ema' if model.get('ema') else 'model']  # load model

    hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
    # Cast floating-point tensors to FP32 without calling model.float(), which would convert the
    # caller's module in place.
    state_dict = {k: v.float() if v.is_floating_point() else v for k, v in model.state_dict().items()}
    hub_model.load_state_dict(state_dict)  # load state_dict
    hub_model.names = model.names  # class names
    return hub_model.autoshape() if autoshape else hub_model


if __name__ == '__main__':
    # Build the pretrained small model; swap in the commented line to try a custom checkpoint instead
    model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True)
    # model = custom(path_or_model='path/to/model.pt')

    # Smoke-test batched inference over several input kinds: PIL image, filename, URI and numpy array
    import numpy as np
    from PIL import Image

    pil_img = Image.open('data/images/bus.jpg')
    batch = [
        pil_img,  # PIL image
        'data/images/zidane.jpg',  # local filename
        'https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg',  # URI
        np.zeros((640, 480, 3)),  # numpy array
    ]

    results = model(batch)
    results.print()
    results.save()