File size: 1,539 Bytes
319d3b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
from typing import Dict, Union
import numpy as np
import PIL
import torch
from mivolo.structures import PersonAndFaceResult
from ultralytics import YOLO
# from ultralytics.yolo.engine.results import Results
# Work around an ultralytics issue: CUBLAS_WORKSPACE_CONFIG has to be cleared
# once the ultralytics import above has run.
# Deleting the key through os.environ removes it from the process environment
# and keeps the os.environ mapping in sync; a bare os.unsetenv() call leaves a
# stale entry in os.environ that later readers (or copies of the mapping
# handed to child processes) would still see.
os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
# Also clear the C-level environment directly, in case the variable was set
# without going through os.environ (the pop above would then be a no-op).
os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
class Detector:
    """Wrapper around an ultralytics ``YOLO`` model.

    ``predict`` and ``track`` forward an image to the underlying model with
    the thresholds configured at construction time and wrap the first
    returned result in a ``PersonAndFaceResult``.
    """

    def __init__(
        self,
        weights: str,
        device: str = "cpu",
        half: bool = True,
        verbose: bool = False,
        conf_thresh: float = 0.4,
        iou_thresh: float = 0.7,
    ):
        """Load the model and prepare the keyword arguments for inference.

        Args:
            weights: Passed as-is to ``YOLO(...)``.
            device: Converted to a ``torch.device`` and stored on
                ``self.device``. Within this class it is only used to decide
                whether half precision is enabled; it is not forwarded to
                the ``predict``/``track`` calls.
            half: Request half precision. Ignored (treated as off) when the
                device type is ``"cpu"``.
            verbose: Forwarded to the YOLO calls as ``verbose``.
            conf_thresh: Forwarded to the YOLO calls as ``conf``.
            iou_thresh: Forwarded to the YOLO calls as ``iou``.
        """
        model = YOLO(weights)
        model.fuse()
        self.yolo = model

        self.device = torch.device(device)
        on_cpu = self.device.type == "cpu"
        # Half precision is only kept when it was requested and we are not on CPU.
        self.half = half and not on_cpu

        if self.half:
            self.yolo.model = self.yolo.model.half()

        # Class-index -> name mapping as reported by the loaded model.
        self.detector_names: Dict[int, str] = self.yolo.model.names

        # Shared keyword arguments for every predict()/track() call.
        self.detector_kwargs = dict(
            conf=conf_thresh,
            iou=iou_thresh,
            half=self.half,
            verbose=verbose,
        )

    def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
        """Run ``self.yolo.predict`` on ``image`` and wrap its first result."""
        outputs = self.yolo.predict(image, **self.detector_kwargs)
        return PersonAndFaceResult(outputs[0])

    def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
        """Run ``self.yolo.track`` on ``image`` and wrap its first result.

        ``persist=True`` is always passed to the tracker; presumably this
        keeps track state between consecutive calls -- confirm against the
        ultralytics documentation.
        """
        outputs = self.yolo.track(image, persist=True, **self.detector_kwargs)
        return PersonAndFaceResult(outputs[0])
|