File size: 1,955 Bytes
07b7ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""

YOLO-NAS model interface.



Usage - Predict:

    from ultralytics import NAS



    model = NAS('yolo_nas_s')

    results = model.predict('ultralytics/assets/bus.jpg')

"""

from pathlib import Path

import torch

from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode

from .predict import NASPredictor
from .val import NASValidator


class NAS(Model):
    """
    YOLO-NAS detection model wrapper built on the shared `Model` interface.

    Only pre-trained models are accepted: either a '.pt' checkpoint path or a
    suffix-less model name that is resolved through `super_gradients`.
    YAML model definitions are rejected.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Initialize the NAS model for the 'detect' task.

        Args:
            model (str): Path to a '.pt' checkpoint or a suffix-less model name.

        Raises:
            AssertionError: If `model` has a '.yaml' or '.yml' suffix.
        """
        # NOTE: `assert` is stripped when Python runs with -O; it is kept here so the
        # exception type seen by existing callers does not change.
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()
    def _load(self, weights: str, task: str):
        """
        Load NAS weights and attach the attributes the rest of the interface reads.

        Args:
            weights (str): Path to a '.pt' checkpoint, or a suffix-less model name
                passed to `super_gradients.training.models.get` with
                `pretrained_weights='coco'`.
            task (str): Unused here; the model task is always set to 'detect'.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' or ''.
        """
        # Imported lazily so that super_gradients is only required when a NAS model is loaded.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
            self.model = torch.load(weights)
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Previously this case fell through without assigning a model, and the attribute
            # assignments below failed with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' checkpoint or a "
                f"model name without a suffix, got suffix '{suffix}'.")

        # Standardize model
        self.model.fuse = lambda verbose=True: self.model  # no-op fuse that returns the model unchanged
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))  # index -> class name
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            The value returned by `model_info` for the loaded model, called with imgsz=640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Return the mapping from the 'detect' task to its predictor and validator classes."""
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}