|
from super_gradients.training import models |
|
import supervision as sv |
|
import torch |
|
import cv2 |
|
|
|
|
|
|
|
|
|
|
|
def predict(img, model, conf=0.70):
    """Run detection on an image file and draw the detected boxes on it.

    Args:
        img: Path of the image file; it is read with ``cv2.imread``.
        model: Object exposing ``predict(image, conf=...)``. The first item
            of what it returns must provide ``prediction.bboxes_xyxy``,
            ``prediction.confidence``, ``prediction.labels`` and
            ``class_names``.
        conf: Confidence threshold forwarded to ``model.predict``.
            Defaults to 0.70, the value previously hard-coded here.

    Returns:
        A tuple ``(annotated_frame, labels)``: the value returned by
        ``sv.BoxAnnotator().annotate`` for a copy of the loaded image, and
        the list of ``"<class name> <confidence>"`` strings, one per
        detection.

    Raises:
        OSError: If ``cv2.imread`` could not load an image from ``img``.
    """
    image = cv2.imread(img)
    if image is None:
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising; fail here with a clear message rather than
        # somewhere inside model.predict.
        raise OSError(f"Could not read image from {img!r}")

    # NOTE(review): cv2.imread yields BGR channel order; confirm that this is
    # the order the model expects.
    result = list(model.predict(image, conf=conf))[0]
    prediction = result.prediction

    detections = sv.Detections(
        xyxy=prediction.bboxes_xyxy,
        confidence=prediction.confidence,
        class_id=prediction.labels.astype(int),
    )

    # Read the per-detection fields by attribute instead of unpacking the
    # tuples produced by iterating ``detections``: the number of fields in
    # those tuples is not the same across supervision releases.
    class_names = result.class_names
    labels = [
        f"{class_names[class_id]} {confidence:0.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(
        scene=image.copy(),  # annotate a copy so the loaded image stays untouched
        detections=detections,
        labels=labels,
    )

    return annotated_frame, labels
|
|
|
|
|
def setup(checkpoint_path="model.pth", num_classes=2, model_arch="yolo_nas_l", device=None):
    """Build the detection model, load its checkpoint and move it to a device.

    Called without arguments this behaves as before: a ``yolo_nas_l`` model
    with 2 classes is loaded from ``model.pth`` and placed on CUDA when
    available, otherwise on the CPU.

    Args:
        checkpoint_path: Path of the checkpoint passed to ``models.get``.
        num_classes: Number of classes passed to ``models.get``.
        model_arch: Architecture name passed to ``models.get``.
        device: Target device for ``.to(...)``. When ``None``, ``'cuda'`` is
            used if ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.

    Returns:
        The model returned by ``models.get(...).to(device)``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = models.get(
        model_arch,
        num_classes=num_classes,
        checkpoint_path=checkpoint_path,
    )
    return model.to(device)