Spaces:
Runtime error
Runtime error
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.

Usage - Predict:
    from ultralytics import NAS

    model = NAS('yolo_nas_s')
    results = model.predict('ultralytics/assets/bus.jpg')
"""
from pathlib import Path | |
import torch | |
from ultralytics.engine.model import Model | |
from ultralytics.utils.torch_utils import model_info, smart_inference_mode | |
from .predict import NASPredictor | |
from .val import NASValidator | |
class NAS(Model):
    """
    YOLO-NAS object-detection model wrapper.

    Only pre-trained models are supported, in one of two forms:
      - a local '.pt' file containing a pickled model object;
      - a suffix-less super-gradients model name (e.g. 'yolo_nas_s'), fetched with COCO weights.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Initialize a YOLO-NAS model for the 'detect' task.

        Args:
            model (str | Path): Path to a '.pt' weights file, or a super-gradients model name without a suffix.

        Raises:
            AssertionError: If a '.yaml'/'.yml' model config is given (building from scratch is unsupported).
                Note that this check is skipped when Python runs with -O.
        """
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()  # load without gradient tracking
    def _load(self, weights: str, task: str):
        """
        Load a pre-trained YOLO-NAS model and attach the attributes the Ultralytics engine reads.

        Args:
            weights (str): Path to a '.pt' file, or a super-gradients model name with no file suffix.
            task (str): Requested task; unused here because the model is always tagged 'detect'.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' or no suffix at all.
        """
        # Imported here rather than at module level, presumably so super-gradients is only
        # required when a NAS model is actually loaded.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # torch.load unpickles arbitrary objects: only load checkpoints from trusted sources.
            # map_location='cpu' lets GPU-saved checkpoints load on machines without CUDA
            # (assumes downstream code moves the model to its target device — TODO confirm).
            self.model = torch.load(weights, map_location='cpu')
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Previously fell through with self.model unset and failed later with an opaque AttributeError.
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' file or a super-gradients model name."
            )

        # Standardize model: attach the attributes/hooks the Ultralytics engine expects.
        self.model.fuse = lambda verbose=True: self.model  # fusing is a no-op; return the model unchanged
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))  # {class index: class name}
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Log model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            Whatever `model_info` returns for the wrapped model, evaluated at imgsz=640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """
        Map the 'detect' task to the NAS predictor and validator classes.

        Must be a property: the Ultralytics `Model` base class reads it as an attribute and
        indexes it directly (e.g. `self.task_map[self.task]['predictor']`), so a plain method
        would not be subscriptable.
        """
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}