#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import time
from argparse import ArgumentParser
from itertools import compress

import mmcv
from mmcv.utils import ProgressBar

from mmocr.apis import init_detector, model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset  # noqa: F401
from mmocr.models import build_detector  # noqa: F401
from mmocr.utils import get_root_logger, list_from_file, list_to_file


def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Save predicted results to txt files.

    Three files are written into ``res_dir``:

    - ``results.txt``: one ``"<img_path> <pred> <gt>"`` line per image.
    - ``correct.txt``: the subset of those lines where ``pred == gt``.
    - ``wrong.txt``: the subset of those lines where ``pred != gt``.

    Note that an image without a ground-truth label carries ``gt == ''`` and
    is therefore classified as wrong unless the prediction is also empty.

    Args:
        img_paths (list[str]): Path of each processed image.
        pred_labels (list[str]): Predicted text for each image.
        gt_labels (list[str]): Ground-truth text for each image
            ('' when the image list gave none).
        res_dir (str): Directory the txt files are written to.
    """
    assert len(img_paths) == len(pred_labels) == len(gt_labels)
    corrects = [pred == gt for pred, gt in zip(pred_labels, gt_labels)]
    wrongs = [not c for c in corrects]
    lines = [
        f'{img} {pred} {gt}'
        for img, pred, gt in zip(img_paths, pred_labels, gt_labels)
    ]
    list_to_file(osp.join(res_dir, 'results.txt'), lines)
    list_to_file(osp.join(res_dir, 'correct.txt'), compress(lines, corrects))
    list_to_file(osp.join(res_dir, 'wrong.txt'), compress(lines, wrongs))


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    return parser.parse_args()


def main():
    """Run text recognition over an image list and save / evaluate results.

    Each non-blank line of ``img_list`` is ``<relative_img_path> [gt_label]``.
    For every line the image under ``img_root_path`` is recognised. Unless
    ``--show`` is given, a visualisation is written to
    ``<out_dir>/out_vis_dir``. When a GT label is present, that visualisation
    is also copied into ``<out_dir>/correct`` or ``<out_dir>/wrong``.

    Per-image results are saved with :func:`save_results`. If every
    processed line carried a GT label, OCR metrics are computed and logged to
    ``<out_dir>/<timestamp>.log``.

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
    """
    args = _parse_args()

    # Create the output dir before the logger: the log file lives inside it,
    # and a plain logging file handler does not create missing parent dirs.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap a wrapper object (presumably a DataParallel-style wrapper)
        # so that ``show_result`` is called on the underlying model.
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)

    img_paths, pred_labels, gt_labels = [], [], []

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank line (e.g. a trailing newline in the list file): nothing
            # to infer, and indexing item_list[0] would raise IndexError.
            continue
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1

        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten sub-directories into the file name so all visualisations
        # land directly in out_vis_dir without name clashes.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            # None (not '') is the conventional "do not write a file" value
            # for show_result-style APIs; it is also falsy, so truthiness
            # checks behave the same as with ''.
            'out_file': None if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show no visualisation file is written, so there is nothing
        # to copy into the correct/wrong folders.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Evaluate only when every processed image had a GT label and at least
    # one image was processed (metrics over an empty set are meaningless).
    if pred_labels and num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()