File size: 3,683 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from argparse import ArgumentParser

import mmcv
from mmcv.utils import ProgressBar

from mmocr.apis import init_detector, model_inference
from mmocr.models import build_detector  # noqa: F401
from mmocr.utils import list_from_file, list_to_file


def gen_target_path(target_root_path, src_name, suffix):
    """Gen target file path.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of target file.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    file_name = osp.split(src_name)[-1]
    name = osp.splitext(file_name)[0]
    return osp.join(target_root_path, name + suffix)


def save_results(result, out_dir, img_name, score_thr=0.3):
    """Save result of detected bounding boxes (quadrangle or polygon) to txt
    file.

    Args:
        result (dict): Text Detection result for one image.
        img_name (str): Image file name.
        out_dir (str): Dir of txt files to save detected results.
        score_thr (float, optional): Score threshold to filter bboxes.
    """
    assert 'boundary_result' in result
    assert score_thr > 0 and score_thr < 1

    txt_file = gen_target_path(out_dir, img_name, '.txt')
    valid_boundary_res = [
        res for res in result['boundary_result'] if res[-1] > score_thr
    ]
    lines = [
        ','.join([str(round(x)) for x in row]) for row in valid_boundary_res
    ]
    list_to_file(txt_file, lines)


def main():
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save '
        'visualize images '
        'and bbox')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    assert 0 < args.score_thr < 1

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    for line in lines:
        progressbar.update()
        img_path = osp.join(args.img_root, line.strip())
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        img_name = osp.basename(img_path)
        # save result
        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)
        # show result
        out_file = osp.join(out_vis_dir, img_name)
        kwargs_dict = {
            'score_thr': args.score_thr,
            'show': False,
            'out_file': out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()