MMOCR / tools /recog_test_imgs.py
tomofi's picture
Add application file
2366e36
raw
history blame
No virus
4.56 kB
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import time
from argparse import ArgumentParser
from itertools import compress
import mmcv
from mmcv.utils import ProgressBar
from mmocr.apis import init_detector, model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import get_root_logger, list_from_file, list_to_file
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Write recognition results to text files under ``res_dir``.

    Three files are written:

    - ``results.txt``: one ``<img> <pred> <gt>`` line per image.
    - ``correct.txt``: the subset of those lines where ``pred == gt``.
    - ``wrong.txt``: the subset of those lines where ``pred != gt``.

    Args:
        img_paths (list[str]): Image paths, one per sample.
        pred_labels (list[str]): Predicted text for each image.
        gt_labels (list[str]): Ground-truth text for each image.
        res_dir (str): Directory the three result files are written to.
    """
    assert len(img_paths) == len(pred_labels) == len(gt_labels)

    every_line, matched, mismatched = [], [], []
    for img, pred, gt in zip(img_paths, pred_labels, gt_labels):
        entry = f'{img} {pred} {gt}'
        every_line.append(entry)
        # Route the same formatted line into exactly one of the two subsets.
        bucket = matched if pred == gt else mismatched
        bucket.append(entry)

    for file_name, content in (('results.txt', every_line),
                               ('correct.txt', matched),
                               ('wrong.txt', mismatched)):
        list_to_file(osp.join(res_dir, file_name), content)
def main():
    """Run text recognition over a list of images and save the results.

    Command-line arguments:
        img_root_path: Root directory that image paths in the list are
            joined onto.
        img_list: Text file with one sample per line. The first
            whitespace-separated token is the image path (relative to
            ``img_root_path``); an optional second token is the ground-truth
            label. Blank lines are ignored.
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out_dir: Directory for the log file, visualizations and result
            text files (default ``./results``).
        --show: Pass ``show=True`` and an empty ``out_file`` to
            ``model.show_result`` instead of an output path.
        --device: Device used for inference (default ``cuda:0``).

    Outputs written under ``out_dir``:
        - ``<timestamp>.log``: log file handed to the root logger.
        - ``out_vis_dir/``: output path given to ``show_result`` for each
          image (only when ``--show`` is not set).
        - ``correct/`` and ``wrong/``: copies of the visualizations of
          labelled images, split by whether the prediction equals the label
          (only when ``--show`` is not set).
        - ``results.txt``, ``correct.txt``, ``wrong.txt``: see
          ``save_results``.

    Evaluation metrics are logged only if every processed line carried a
    ground-truth label.

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file is placed inside out_dir, so make sure the directory
    # exists before the logger is set up rather than relying on
    # get_root_logger to create it (NOTE(review): not visible from here
    # whether it does).
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Unwrap a wrapper object that exposes the real model as ``.module``.
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []

    # Drop blank lines up front: they carry no image path and would
    # otherwise fail on ``item_list[0]`` below.
    lines = [line for line in list_from_file(args.img_list) if line.strip()]
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        img_file = item_list[0]
        # An empty label marks a sample without ground truth.
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten sub-directories into the file name so all visualizations
        # can live in one directory.
        out_img_name = img_file.replace('/', '_')
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show, show_result is not given out_file, so there is no
        # freshly written visualization at that path to sort into
        # correct/wrong; copying it would fail or pick up a stale file.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Metrics are only meaningful when every sample has a ground-truth label.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
# Script entry point: run the recognition test only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()