# EasyAnimate/easyanimate/video_caption/easyocr_detection_patched.py
# bubbliiiing - "Create Code" - 19fe404 - 4.33 kB
# (file-viewer links: raw, history, blame)
"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.
1. Disable DataParallel.
"""
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from PIL import Image
from collections import OrderedDict
import cv2
import numpy as np
from .craft_utils import getDetBoxes, adjustResultCoordinates
from .imgproc import resize_aspect_ratio, normalizeMeanVariance
from .craft import CRAFT
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module`` component removed.

    Only the first key is inspected: if it starts with ``"module"`` (the
    prefix that wrapping a model in ``torch.nn.DataParallel`` adds to
    parameter names), the first dot-separated component is dropped from
    every key. Otherwise the keys are copied unchanged.

    Args:
        state_dict: Mapping from dotted parameter names to values.

    Returns:
        A new ``OrderedDict`` with the (possibly shortened) keys, in the
        iteration order of ``state_dict``. An empty mapping yields an empty
        ``OrderedDict`` instead of raising ``IndexError``.
    """
    first_key = next(iter(state_dict), None)
    if first_key is None:
        # Nothing to inspect or copy.
        return OrderedDict()

    # 1 drops the leading component of every key, 0 keeps keys intact.
    start_idx = 1 if first_key.startswith("module") else 0
    return OrderedDict(
        (".".join(key.split(".")[start_idx:]), value)
        for key, value in state_dict.items()
    )
def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
    """Run ``net`` on one image or a batch and post-process the output into boxes.

    Args:
        canvas_size: Forwarded to ``resize_aspect_ratio``.
        mag_ratio: Forwarded to ``resize_aspect_ratio``.
        net: Callable returning a pair; the first item is iterated once per
            image, the second is ignored.
        image: A 4-D ``np.ndarray`` is treated as a batch of images; any
            other value is treated as a single image.
        text_threshold: Forwarded to ``getDetBoxes``.
        link_threshold: Forwarded to ``getDetBoxes``.
        low_text: Forwarded to ``getDetBoxes``.
        poly: Forwarded to ``getDetBoxes``.
        device: Target passed to ``Tensor.to`` for the input batch.
        estimate_num_chars: Forwarded to ``getDetBoxes``. When true, every
            box is replaced by a ``(box, mapper_entry)`` tuple.

    Returns:
        ``(boxes_list, polys_list)`` with one entry per item of the network
        output. Polygon slots that are ``None`` are filled with the
        corresponding box entry.
    """
    # Batch vs. single image: only a 4-D ndarray counts as a batch.
    is_batch = isinstance(image, np.ndarray) and len(image.shape) == 4
    frames = image if is_batch else [image]

    # Resize each frame; the heatmap size returned by the helper is not needed.
    resized_frames = []
    for frame in frames:
        resized, target_ratio, _ = resize_aspect_ratio(
            frame,
            canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=mag_ratio,
        )
        resized_frames.append(resized)
    # The ratio of the last resized frame is used for every image.
    ratio_h = ratio_w = 1 / target_ratio

    # Normalise, move the channel axis first, and stack into one tensor.
    chw_frames = [
        np.transpose(normalizeMeanVariance(resized), (2, 0, 1))
        for resized in resized_frames
    ]
    batch = torch.from_numpy(np.array(chw_frames)).to(device)

    # Inference only; no gradients are tracked.
    with torch.no_grad():
        y, _ = net(batch)

    boxes_list, polys_list = [], []
    for out in y:
        # Channel 0 of the last axis is the text score map, channel 1 the link map.
        score_text = out[:, :, 0].cpu().data.numpy()
        score_link = out[:, :, 1].cpu().data.numpy()

        boxes, polys, mapper = getDetBoxes(
            score_text,
            score_link,
            text_threshold,
            link_threshold,
            low_text,
            poly,
            estimate_num_chars,
        )

        # Coordinate adjustment with the resize ratio computed above.
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)

        if estimate_num_chars:
            # Lists are needed so that tuples can be stored per slot.
            boxes = list(boxes)
            polys = list(polys)
            for idx in range(len(polys)):
                boxes[idx] = (boxes[idx], mapper[idx])
                if polys[idx] is None:
                    polys[idx] = boxes[idx]
        else:
            for idx in range(len(polys)):
                if polys[idx] is None:
                    polys[idx] = boxes[idx]

        boxes_list.append(boxes)
        polys_list.append(polys)

    return boxes_list, polys_list
def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Build a CRAFT network, load its weights and return it in eval mode.

    Args:
        trained_model: Checkpoint passed to ``torch.load``.
        device: Used as ``map_location`` when loading. When it equals the
            string ``'cpu'`` the model stays where it was created and may be
            dynamically quantized; otherwise the model is moved to ``device``.
        quantize: On the ``'cpu'`` path, attempt in-place dynamic ``qint8``
            quantization. A failure is reported as a ``RuntimeWarning`` and
            the unquantized model is used.
        cudnn_benchmark: Value assigned to ``cudnn.benchmark`` on the
            non-``'cpu'`` path.

    Returns:
        The CRAFT model with loaded weights, switched to evaluation mode.
        The model is not wrapped in ``DataParallel``.
    """
    import warnings

    net = CRAFT()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))

    if device == 'cpu':
        if quantize:
            try:
                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
            except Exception as exc:
                # Quantization is best-effort, so report the failure and keep
                # the float model instead of hiding the error or crashing.
                warnings.warn(
                    f"Dynamic quantization failed ({exc!r}); "
                    "using the unquantized model.",
                    RuntimeWarning,
                    stacklevel=2,
                )
    else:
        # DataParallel is intentionally not used here.
        net = net.to(device)
        cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net
def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
    """Detect text regions and return them as flat ``int32`` coordinate arrays.

    Args:
        detector: Network forwarded to ``test_net``.
        image: Single image or batch forwarded to ``test_net``.
        canvas_size, mag_ratio, text_threshold, link_threshold, low_text,
            poly, device: Forwarded unchanged to ``test_net``.
        optimal_num_chars: When not ``None``, character-count estimation is
            requested from ``test_net`` and each image's entries are sorted
            by how close their second element is to this value.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one inner list per image; every inner list holds one
        1-D ``np.int32`` array per detected region.
    """
    want_char_estimate = optimal_num_chars is not None
    _, polys_per_image = test_net(
        canvas_size,
        mag_ratio,
        detector,
        image,
        text_threshold,
        link_threshold,
        low_text,
        poly,
        device,
        want_char_estimate,
    )

    if want_char_estimate:
        def distance_to_target(entry):
            # Each entry is a (region, estimate) pair; rank by the estimate.
            return abs(optimal_num_chars - entry[1])

        polys_per_image = [
            [region for region, _ in sorted(entries, key=distance_to_target)]
            for entries in polys_per_image
        ]

    # Flatten every region to a 1-D int32 array, grouped per image.
    return [
        [np.array(region).astype(np.int32).reshape(-1) for region in regions]
        for regions in polys_per_image
    ]