import plasma.functional as f
import plasma.huggingface as hf
import plasma.meta as meta

from .config import Config
from .model import YOLORunner
from .preprocesses import Preprocessor
from .post_processes import TableRestore
from ultralytics import YOLO
from .stamp_processing.detector import StampDetector
from .stamp_processing.callback import remove_potiential_table_fp, remove_box_by_idc


class Engine(f.Pipe):
    """Table-detection pipeline built on a YOLO model, with optional stamp filtering.

    Components are handed to ``f.Pipe.__init__`` as keyword arguments and are
    then read back as attributes (``self.table_detector``, ``self.stamp_detector``,
    ``self.cfg``, ...); presumably ``f.Pipe`` stores them that way — confirm
    against plasma.
    """

    def __init__(self, cfg: Config = None, table_checkpoint=None, line_checkpoint=None, verbose=True):
        """Build the engine and all of its components.

        Args:
            cfg: Engine configuration. A default ``Config()`` is used when ``None``.
            table_checkpoint: Local path of the table-detection weights. When ``None``
                they are fetched with ``hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)``.
            line_checkpoint: Not used by this class.
                NOTE(review): accepted but ignored — confirm whether a line detector
                was meant to be built from it.
            verbose: Not used by this class.
                NOTE(review): not forwarded to ``_build_table_detector`` (which
                defaults to ``verbose=False``) — confirm intent.
        """
        # Resolve the default before anything else: every component below reads
        # attributes from ``cfg``, so it must never be None past this point.
        if cfg is None:
            cfg = Config()
        super().__init__(
            cfg=cfg,
            standard_width=cfg.STANDARD_WIDTH,
            preprocessor=Preprocessor(cfg.STANDARD_WIDTH, cfg.STANDARD_INTERPOLATION),
            post_processor=TableRestore(),
            # The detector may register a bound-method callback that reads
            # ``self.stamp_detector`` and ``self.cfg``; those only need to exist
            # when prediction runs, not while the detector is being built.
            table_detector=self._build_table_detector(cfg, table_checkpoint),
            stamp_detector=StampDetector(model_path=cfg.STAMP_DETECTION_CHECKPOINT, device=cfg.DEVICE),
        )
        self.stamp_detector.model.to(cfg.DEVICE)

    def run(self, image):
        """Detect tables in *image* and return the table detector's output unchanged.

        NOTE(review): the resize/restore steps are currently disabled. When enabled,
        the image is passed through ``self.preprocessor`` first, and the results go
        through ``self.post_processor(tables, ratio)`` afterwards, with
        ``ratio = 1.0 * self.standard_width / image.shape[1]``.
        """
        tables = self.table_detector(image)
        return tables

    def _build_table_detector(self, cfg: Config, checkpoint=None, verbose=False):
        """Create the YOLO-based table detector.

        Args:
            cfg: Configuration providing checkpoint, device and inference settings.
            checkpoint: Local weights path; downloaded via ``hf.download_file`` when ``None``.
            verbose: Forwarded to ``YOLORunner``.

        Returns:
            A ``YOLORunner`` wrapping the loaded YOLO model.
        """
        if checkpoint is None:
            checkpoint = hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)
        model = YOLO(checkpoint).to(cfg.DEVICE)
        if cfg.USE_STAMP_DETECTION:
            # Post-filter table boxes with the stamp detector once YOLO has
            # finished its own post-processing for a prediction.
            model.add_callback("on_predict_postprocess_end", self._stamp_detection_callback)
        return YOLORunner(model, cfg.YOLO_IMAGE_SIZE, cfg.CONF_THRESHOLD, cfg.HEIGHT_EXPAND_RATIO, cfg.DEVICE, verbose)

    def _stamp_detection_callback(self, predictor):
        """Remove table boxes flagged by ``remove_potiential_table_fp``.

        Registered as the ultralytics ``on_predict_postprocess_end`` callback.
        Replaces ``predictor.results[0].boxes`` in place. Presumably it drops
        tables that are really stamps, judged by IoU against
        ``cfg.STAMP_REMOVING_IOU_THRESHOLD`` — confirm in
        ``stamp_processing.callback``.

        Raises:
            ValueError: if the predictor holds more than one result (batch size != 1).
        """
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``;
        # otherwise images past the first in a batch would go unfiltered silently.
        if len(predictor.results) != 1:
            raise ValueError('Only support batch size 1')
        preds = predictor.results[0]
        remove_idc = remove_potiential_table_fp(
            self.stamp_detector,
            preds.orig_img,
            preds.boxes.xyxy,
            self.cfg.STAMP_REMOVING_IOU_THRESHOLD,
        )
        predictor.results[0].boxes = remove_box_by_idc(preds.boxes, remove_idc)