table_detection_3 / engine.py
nguyenp99's picture
Update engine.py
dbc73de verified
raw
history blame
2.23 kB
import plasma.functional as f
import plasma.training.utils as utils
import plasma.huggingface as hf
import plasma.meta as meta
from .config import Config
from .model import YOLORunner
from .preprocesses import Preprocessor
from .post_processes import TableRestore
from ultralytics import YOLO
from .stamp_processing.detector import StampDetector
from .stamp_processing.callback import remove_potiential_table_fp, remove_box_by_idc
class Engine(f.Pipe):
    """Table-detection pipeline built around a YOLO model.

    The pipe holds a preprocessor, a ``TableRestore`` post-processor, a YOLO
    table detector (wrapped in ``YOLORunner``) and a ``StampDetector``. When
    ``cfg.USE_STAMP_DETECTION`` is set, the stamp detector is hooked into the
    YOLO prediction loop to discard table boxes it flags.
    """

    def __init__(self, cfg: Config = None, table_checkpoint=None, line_checkpoint=None, verbose=True):
        """Build all pipeline components.

        Args:
            cfg: Pipeline configuration. A default ``Config()`` is created when
                omitted.
            table_checkpoint: YOLO table-detection weights. When ``None``,
                ``cfg.TABLE_DETECTION_CHECKPOINT`` is fetched via
                ``hf.download_file``.
            line_checkpoint: Currently unused; kept for interface compatibility.
            verbose: Forwarded to the ``YOLORunner`` wrapping the table detector.
        """
        # Resolve the default config BEFORE any attribute access below: every
        # component argument reads from ``cfg``, so it must not still be None.
        cfg = cfg or Config()
        # NOTE(review): f.Pipe presumably exposes each keyword as an attribute
        # (self.stamp_detector / self.cfg are read that way below) - confirm.
        super().__init__(
            cfg=cfg,
            standard_width=cfg.STANDARD_WIDTH,
            preprocessor=Preprocessor(cfg.STANDARD_WIDTH, cfg.STANDARD_INTERPOLATION),
            post_processor=TableRestore(),
            table_detector=self._build_table_detector(cfg, table_checkpoint, verbose=verbose),
            stamp_detector=StampDetector(model_path=cfg.STAMP_DETECTION_CHECKPOINT, device=cfg.DEVICE),
        )
        self.stamp_detector.model.to(cfg.DEVICE)

    def run(self, image):
        """Detect tables in ``image`` and return the table detector's raw output."""
        # NOTE: preprocessing/rescaling and TableRestore post-processing are
        # currently disabled; the lines below document the intended flow.
        # ratio = 1.0*self.standard_width/image.shape[1]
        # image = self.preprocessor(image)
        tables = self.table_detector(image)
        # tables = self.post_processor(tables, ratio)
        return tables

    def _build_table_detector(self, cfg: Config, checkpoint=None, verbose=False):
        """Load the YOLO table model and wrap it in a ``YOLORunner``.

        Args:
            cfg: Configuration supplying device, image size, thresholds and the
                default checkpoint id.
            checkpoint: Weights to load. When ``None``,
                ``cfg.TABLE_DETECTION_CHECKPOINT`` is downloaded first.
            verbose: Passed through to ``YOLORunner``.

        Returns:
            The configured ``YOLORunner``.
        """
        if checkpoint is None:
            checkpoint = hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)
        model = YOLO(checkpoint).to(cfg.DEVICE)
        if cfg.USE_STAMP_DETECTION:
            # The callback runs only at prediction time, i.e. after __init__
            # has finished and self.stamp_detector / self.cfg exist.
            model.add_callback("on_predict_postprocess_end", self._stamp_detection_callback)
        return YOLORunner(model, cfg.YOLO_IMAGE_SIZE, cfg.CONF_THRESHOLD, cfg.HEIGHT_EXPAND_RATIO, cfg.DEVICE, verbose)

    def _stamp_detection_callback(self, predictor):
        """YOLO ``on_predict_postprocess_end`` hook that prunes table boxes.

        Runs the stamp detector on the original image and drops the table boxes
        selected by ``remove_potiential_table_fp``, using
        ``cfg.STAMP_REMOVING_IOU_THRESHOLD``. Mutates ``predictor.results[0]``
        in place.
        """
        assert len(predictor.results) == 1, 'Only support batch size 1'
        preds = predictor.results[0]
        # presumably returns indices of table boxes overlapping detected stamps
        # above the IoU threshold - verify against the helper.
        remove_idc = remove_potiential_table_fp(
            self.stamp_detector,
            preds.orig_img,
            preds.boxes.xyxy,
            self.cfg.STAMP_REMOVING_IOU_THRESHOLD,
        )
        # ``preds`` is the same object as predictor.results[0].
        preds.boxes = remove_box_by_idc(preds.boxes, remove_idc)