File size: 1,997 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.base_dataset import BaseDataset


def _create_dummy_ann_file(ann_file):
    ann_info1 = 'sample1.jpg hello'
    ann_info2 = 'sample2.jpg world'

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(ann_info + '\n')


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineStrParser', keys=['file_name', 'text']))
    return loader


def test_custom_dataset():
    """Exercise ``BaseDataset`` on a tiny two-line annotation file.

    The same checks run with ``test_mode`` both True and False: dataset
    length, group flags, the output of ``prepare_train_img``,
    ``prepare_test_img`` and ``__getitem__``, ``_get_next_index``, that
    ``format_results`` leaves its input unchanged, and that ``evaluate``
    raises ``NotImplementedError``.
    """
    # Use the context manager so the temporary directory is removed even
    # when an assertion below fails; a trailing ``cleanup()`` call would be
    # skipped in that case and leak the directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        _create_dummy_ann_file(ann_file)
        loader = _create_dummy_loader()

        for mode in [True, False]:
            dataset = BaseDataset(
                ann_file, loader, pipeline=[], test_mode=mode)

            # test len
            assert len(dataset) == len(dataset.data_infos)

            # test set group flag
            assert np.allclose(dataset.flag, [0, 0])

            # test prepare_train_img
            expect_results = {
                'img_info': {
                    'file_name': 'sample1.jpg',
                    'text': 'hello'
                },
                'img_prefix': ''
            }
            assert dataset.prepare_train_img(0) == expect_results

            # test prepare_test_img
            assert dataset.prepare_test_img(0) == expect_results

            # test __getitem__
            assert dataset[0] == expect_results

            # test get_next_index
            assert dataset._get_next_index(0) == 1

            # test format_results: it must not mutate its input. A deep copy
            # is required; a shallow copy would share the nested ``img_info``
            # dict and hide any in-place modification of it.
            expect_results_copy = copy.deepcopy(expect_results)
            dataset.format_results(expect_results)
            assert expect_results_copy == expect_results

            # test evaluate
            with pytest.raises(NotImplementedError):
                dataset.evaluate(expect_results)