|
import plasma.functional as f |
|
import plasma.huggingface as hf |
|
import plasma.meta as meta |
|
|
|
from .config import Config |
|
from .model import YOLORunner |
|
from .preprocesses import Preprocessor |
|
from .post_processes import TableRestore |
|
|
|
from ultralytics import YOLO |
|
from .stamp_processing.detector import StampDetector |
|
from .stamp_processing.callback import remove_potiential_table_fp, remove_box_by_idc |
|
|
|
|
|
class Engine(f.Pipe):
    """Table-detection pipeline built around a YOLO table detector.

    When ``cfg.USE_STAMP_DETECTION`` is enabled, a post-processing callback is
    registered on the YOLO model that runs a ``StampDetector`` over each
    prediction and removes the table boxes selected by
    ``remove_potiential_table_fp``.

    NOTE(review): the keyword arguments handed to ``f.Pipe.__init__`` are read
    back later as attributes (``self.cfg``, ``self.table_detector``,
    ``self.stamp_detector``), so ``f.Pipe`` presumably stores them on the
    instance — confirm against ``plasma.functional``.
    """

    def __init__(self, cfg: Config = None, table_checkpoint=None, line_checkpoint=None, verbose=True):
        """Build all pipeline components.

        Args:
            cfg: pipeline configuration. A default ``Config()`` is used when None.
            table_checkpoint: path to YOLO table-detection weights. When None,
                the weights are downloaded from ``cfg.TABLE_DETECTION_CHECKPOINT``.
            line_checkpoint: currently unused by this class; kept so existing
                callers that pass it keep working.
            verbose: forwarded to the ``YOLORunner`` wrapping the table model.
        """
        # Resolve the default *before* touching any attribute. Previously only
        # the `cfg=` keyword saw the fallback, while every other argument (and
        # the `.to(cfg.DEVICE)` call below) dereferenced the raw `None`.
        if cfg is None:
            cfg = Config()

        super().__init__(
            cfg=cfg,
            standard_width=cfg.STANDARD_WIDTH,
            preprocessor=Preprocessor(cfg.STANDARD_WIDTH, cfg.STANDARD_INTERPOLATION),
            post_processor=TableRestore(),
            # `verbose` was previously dropped here, so the runner always got False.
            table_detector=self._build_table_detector(cfg, table_checkpoint, verbose=verbose),
            stamp_detector=StampDetector(model_path=cfg.STAMP_DETECTION_CHECKPOINT, device=cfg.DEVICE),
        )

        self.stamp_detector.model.to(cfg.DEVICE)

    def run(self, image):
        """Run the table detector on ``image`` and return its output unchanged.

        NOTE(review): the ``preprocessor`` / ``post_processor`` built in
        ``__init__`` are not applied here — confirm whether ``f.Pipe`` applies
        them or whether this is intentional.
        """
        return self.table_detector(image)

    def _build_table_detector(self, cfg: Config, checkpoint=None, verbose=False):
        """Load the YOLO table model and wrap it in a ``YOLORunner``.

        Args:
            cfg: configuration supplying the checkpoint id, device, image size,
                confidence threshold, height expand ratio and stamp-detection flag.
            checkpoint: local weights path. When None, it is fetched with
                ``hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)``.
            verbose: forwarded to ``YOLORunner``.

        Returns:
            The ``YOLORunner`` wrapping the loaded model.
        """
        if checkpoint is None:
            checkpoint = hf.download_file(cfg.TABLE_DETECTION_CHECKPOINT)

        model = YOLO(checkpoint).to(cfg.DEVICE)

        if cfg.USE_STAMP_DETECTION:
            # Filter table boxes after ultralytics finishes post-processing each batch.
            model.add_callback("on_predict_postprocess_end", self._stamp_detection_callback)

        return YOLORunner(model, cfg.YOLO_IMAGE_SIZE, cfg.CONF_THRESHOLD, cfg.HEIGHT_EXPAND_RATIO, cfg.DEVICE, verbose)

    def _stamp_detection_callback(self, predictor):
        """Ultralytics ``on_predict_postprocess_end`` hook that drops stamp-related table boxes.

        Runs ``remove_potiential_table_fp`` with the stamp detector on the
        original image and the predicted boxes. It then replaces the result's
        boxes with those left after removing the returned indices.

        Raises:
            ValueError: if the predictor holds more than one result (only batch
                size 1 is supported).
        """
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # which would silently filter only the first image of a larger batch.
        if len(predictor.results) != 1:
            raise ValueError('Only support batch size 1')

        preds = predictor.results[0]
        boxes = preds.boxes

        # No detections -> nothing to filter; skip the stamp-detector pass.
        if boxes is None or len(boxes) == 0:
            return

        remove_idc = remove_potiential_table_fp(
            self.stamp_detector,
            preds.orig_img,
            boxes.xyxy,
            self.cfg.STAMP_REMOVING_IOU_THRESHOLD,
        )

        # `preds` is predictor.results[0], so this updates the predictor's result in place.
        preds.boxes = remove_box_by_idc(boxes, remove_idc)
|
|