# Source: table_detection_3/model.py (author: nguyenp99, commit d19b3e3, "Create model.py").
import cv2
import numpy as np
import torch
import plasma.functional as f
from ultralytics import YOLO
class YOLORunner(f.Pipe):
    """Run a YOLO detector on one image and return its boxes, padded vertically.

    Each detected box is stretched up and down by ``height_ratio`` times its own
    height, then clamped to the image's vertical extent. The x coordinates are
    returned as detected (truncated to int) and are not clipped.

    Args:
        model: A loaded ``ultralytics.YOLO`` model.
        image_size: Inference size, forwarded to ``predict`` as ``imgsz``.
        conf_thrs: Confidence threshold, forwarded to ``predict`` as ``conf``.
        height_ratio: Fraction of each box's height added above AND below it.
        device: Device forwarded to ``predict`` (e.g. ``"cpu"`` or ``0``).
        verbose: Forwarded to ``predict`` to control ultralytics logging.
    """

    def __init__(self, model: YOLO, image_size, conf_thrs, height_ratio, device, verbose):
        # NOTE(review): ``run`` reads these back as ``self.<name>``, so f.Pipe
        # presumably stores its keyword arguments as attributes — confirm in
        # plasma.functional.
        super().__init__(
            image_size=image_size,
            conf_thrs=conf_thrs,
            device=device,
            height_ratio=height_ratio,
            verbose=verbose,
        )
        self.model = model

    def run(self, image):
        """Detect boxes in ``image`` and expand each one vertically.

        Args:
            image: Image source passed to ``YOLO.predict`` (e.g. a decoded
                array or a file path). Only the first result is used.

        Returns:
            Integer array with one ``x1, y1, x2, y2`` row per detected box.
            ``y1``/``y2`` are expanded by ``height_ratio`` of the box height
            and clipped to ``[0, image height]``. It has zero rows when nothing
            is detected.
        """
        result = self.model.predict(
            source=image,
            imgsz=self.image_size,
            conf=self.conf_thrs,
            device=self.device,  # honour the constructor's ``device`` argument
            verbose=self.verbose,
        )[0]  # a single input image yields a single Results object
        boxes = result.boxes.xyxy.cpu().numpy().astype(int)

        # Read the height from the Results object instead of ``image.shape[0]``.
        # It is correct for every source type: file paths and PIL images have
        # no ``.shape``, and a BCHW tensor's ``shape[0]`` is the batch size.
        img_height = result.orig_shape[0]

        # Expand the height of the boxes. Round the top edge down and the
        # bottom edge up so the padded box always covers the full margin.
        # An implicit float->int cast would truncate both edges and under-pad
        # the bottom one.
        pad = (boxes[:, 3] - boxes[:, 1]) * self.height_ratio
        boxes[:, 1] = np.floor(boxes[:, 1] - pad).clip(0, None)
        boxes[:, 3] = np.ceil(boxes[:, 3] + pad).clip(0, img_height)
        return boxes