"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.

1. Disable DataParallel.
"""
import logging
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable

from .craft import CRAFT
from .craft_utils import getDetBoxes, adjustResultCoordinates
from .imgproc import resize_aspect_ratio, normalizeMeanVariance
def copyStateDict(state_dict):
    """Strip the leading ``module.`` prefix (added by DataParallel) from every key.

    Args:
        state_dict: Mapping of parameter names to tensors, as produced by
            ``torch.load``. May be empty.

    Returns:
        OrderedDict with the same values. If the first key starts with
        ``module``, the first dotted component is dropped from every key;
        otherwise keys are copied unchanged.
    """
    # Guard: the original indexed list(keys())[0] and raised IndexError on {}.
    if not state_dict:
        return OrderedDict()
    first_key = next(iter(state_dict))
    # All keys are assumed to share the same prefixing as the first one.
    start_idx = 1 if first_key.startswith("module") else 0
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_state_dict[".".join(key.split(".")[start_idx:])] = value
    return new_state_dict
def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
    """Run the CRAFT network on one image (or a 4-D batch) and post-process.

    Returns a pair ``(boxes_list, polys_list)`` with one entry per input
    image. When ``estimate_num_chars`` is true, each box is paired with its
    character-count estimate from ``getDetBoxes``'s mapper.
    """
    # Normalize the input into a list of HWC numpy images.
    if isinstance(image, np.ndarray) and image.ndim == 4:
        images = image
    else:
        images = [image]

    # Resize each image to the working canvas. The inverse ratio of the
    # (last) resize is used below to map detections back to input coords.
    resized = []
    for img in images:
        scaled, target_ratio, _size_heatmap = resize_aspect_ratio(
            img, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        resized.append(scaled)
    ratio_h = ratio_w = 1 / target_ratio

    # Mean/variance normalization, then HWC -> CHW, stacked into an NCHW tensor.
    batch = np.array([np.transpose(normalizeMeanVariance(img), (2, 0, 1)) for img in resized])
    tensor = torch.from_numpy(batch).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        y, feature = net(tensor)

    boxes_list, polys_list = [], []
    for out in y:
        # Channel 0 is the text region score, channel 1 the affinity (link) score.
        score_text = out[:, :, 0].cpu().data.numpy()
        score_link = out[:, :, 1].cpu().data.numpy()

        boxes, polys, mapper = getDetBoxes(
            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)

        # Map detections from heatmap space back to original image coordinates.
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        if estimate_num_chars:
            boxes, polys = list(boxes), list(polys)
        for idx in range(len(polys)):
            if estimate_num_chars:
                boxes[idx] = (boxes[idx], mapper[idx])
            # Fall back to the rectangular box where no polygon was produced.
            if polys[idx] is None:
                polys[idx] = boxes[idx]
        boxes_list.append(boxes)
        polys_list.append(polys)
    return boxes_list, polys_list
def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Load a CRAFT detector checkpoint and prepare it for inference.

    Args:
        trained_model: Path to the checkpoint file for ``torch.load``.
        device: Torch device string ('cpu', 'cuda', ...).
        quantize: On CPU only, attempt dynamic int8 quantization (best effort).
        cudnn_benchmark: Value assigned to ``cudnn.benchmark``.

    Returns:
        The CRAFT network on *device*, in eval mode.
    """
    net = CRAFT()
    # Checkpoint may come from DataParallel training; strip 'module.' prefixes.
    # (The original duplicated this load in both branches of the device check.)
    net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
    if device == 'cpu' and quantize:
        # Best-effort optimization: keep going without quantization on failure,
        # but log it instead of the original bare `except: pass`, which also
        # swallowed KeyboardInterrupt/SystemExit and hid real errors.
        try:
            torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
        except Exception as exc:
            logging.getLogger(__name__).warning("Dynamic quantization failed: %s", exc)
    # DataParallel deliberately disabled (see module docstring).
    net = net.to(device)
    cudnn.benchmark = cudnn_benchmark
    net.eval()
    return net
def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
    """Detect text regions and return per-image lists of flattened polygons.

    Args:
        detector: CRAFT network as returned by ``get_detector``.
        image: Single image or 4-D batch, forwarded to ``test_net``.
        optimal_num_chars: If given, polygons are sorted by how close their
            estimated character count is to this value.

    Returns:
        One list per input image; each entry is a 1-D int32 numpy array of
        polygon coordinates.
    """
    estimate_num_chars = optimal_num_chars is not None
    _bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
                                        image, text_threshold,
                                        link_threshold, low_text, poly,
                                        device, estimate_num_chars)

    if estimate_num_chars:
        # Each entry is (polygon, estimated_char_count); keep polygons ordered
        # by distance of their estimate from the requested count.
        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
                      for polys in polys_list]

    # Flatten each polygon to a 1-D int32 array. Loop variable renamed from
    # `poly`, which shadowed the same-named parameter in the original.
    return [[np.array(box).astype(np.int32).reshape(-1) for box in polys]
            for polys in polys_list]