import os
from typing import List, Union
import numpy as np
import numpy.typing as npt
from .detector import StampDetector
from .module.unet import *
from .preprocess import create_batch
from .utils import REMOVER_WEIGHT_ID, check_image_shape, download_weight
class StampRemover:
    """Detect and remove stamps from document images.

    Combines a stamp detector (bounding boxes) with a U-Net based
    remover that inpaints each detected stamp region.
    """

    def __init__(
        self, detection_weight: Union[str, None] = None, removal_weight: Union[str, None] = None, device: str = "cpu"
    ):
        """Create an object to remove stamps from document images.

        Args:
            detection_weight (Union[str, None]): path to the detector weight,
                forwarded to StampDetector (which resolves None itself).
            removal_weight (Union[str, None]): path to the remover weight, or
                None to download the default weight into ``tmp/``.
            device (str): inference device for the detector. Defaults to "cpu".

        Raises:
            FileNotFoundError: if the remover weight cannot be loaded.
        """
        # assert device == "cpu", "Currently only support cpu inference"
        if removal_weight is None:
            # No path given: download the default weight into ./tmp.
            os.makedirs("tmp", exist_ok=True)
            removal_weight = os.path.join("tmp", "stamp_remover.pkl")
            print("Downloading stamp remover weight from google drive")
            download_weight(REMOVER_WEIGHT_ID, output=removal_weight)
            print(f"Finished downloading. Weight is saved at {removal_weight}")
        try:
            self.remover = UnetInference(removal_weight)  # type: ignore
        except Exception as e:
            print(e)
            print("There is something wrong when loading remover weight")
            print(
                "Please make sure you provide the correct path to the weight "
                "or manually download the weight at "
                f"https://drive.google.com/file/d/{REMOVER_WEIGHT_ID}/view?usp=sharing"
            )
            # Chain the original exception so the root cause stays visible.
            raise FileNotFoundError("Could not load stamp remover weight") from e
        self.detector = StampDetector(detection_weight, device=device)
        # Pixels of extra context taken around each detected box when inpainting.
        self.padding = 3

    def __call__(self, image_list: Union[List[npt.NDArray], npt.NDArray], batch_size: int = 16) -> List[npt.NDArray]:
        """Detect and remove stamps from document images.

        Args:
            image_list (Union[List[npt.NDArray], npt.NDArray]): input images
                (each H x W x C; validated via check_image_shape).
            batch_size (int, optional): detector batch size. Defaults to 16.

        Returns:
            List[np.ndarray]: input images with stamp regions inpainted.

        Raises:
            TypeError: if image_list is not a list or np.ndarray.
        """
        if not isinstance(image_list, (np.ndarray, list)):
            raise TypeError("Invalid Type: Input must be of type list or np.ndarray")
        if len(image_list) == 0:
            return []
        # Only the first image is shape-checked; create_batch groups the rest.
        check_image_shape(image_list[0])
        return self.__batch_removing(image_list, batch_size)  # type:ignore

    def __batch_removing(self, image_list, batch_size=16):  # type: ignore
        """Run detection in shape-grouped batches, then inpaint each box in place."""
        # Group images by shape so the detector can stack them into batches.
        shapes = set(x.shape for x in image_list)
        images_batch, indices = create_batch(image_list, shapes, batch_size)
        detection_predictions = []
        for batch in images_batch:
            if len(batch):
                detection_predictions.extend(self.detector(batch))
        # create_batch reordered images by shape; sort predictions back to the
        # original image order using the returned indices.
        sorted_result = sorted(zip(detection_predictions, indices), key=lambda x: x[1])
        detection_predictions, _ = zip(*sorted_result)
        new_pages = []
        for idx, page_boxes in enumerate(detection_predictions):
            page_img = image_list[idx]
            h, w = page_img.shape[:2]
            for box in page_boxes:
                x_min, y_min, x_max, y_max = box[:4]
                # Expand the detected box by `padding`, clamped to image bounds.
                y0 = max(y_min - self.padding, 0)
                y1 = min(y_max + self.padding, h)
                x0 = max(x_min - self.padding, 0)
                x1 = min(x_max + self.padding, w)
                stamp_area = self.remover([page_img[y0:y1, x0:x1]])  # type:ignore
                # Write the inpainted patch back into the page (in place).
                page_img[y0:y1, x0:x1, :] = stamp_area[0]
            new_pages.append(page_img)
        return new_pages