File size: 2,209 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile

import pytest
import torch

from mmocr.models import build_detector


def _create_dummy_vocab_file(vocab_file):
    """Write a throwaway vocabulary to ``vocab_file``.

    The file receives the lowercase letters 'a' through 'z' in alphabetical
    order, one per line. Existing content at that path is overwritten.
    """
    first_code, last_code = ord('a'), ord('z')
    with open(vocab_file, 'w') as handle:
        for code in range(first_code, last_code + 1):
            handle.write(chr(code) + '\n')


def _get_config_module(fname):
    """Read the configuration file ``fname`` and return the loaded config.

    The file is parsed with ``mmcv.Config.fromfile``; ``mmcv`` is imported
    lazily, only when this helper is actually called.
    """
    from mmcv import Config
    return Config.fromfile(fname)


def _get_detector_cfg(fname):
    """Return an independent copy of the ``model`` section of a config.

    The section is deep copied, so a test may modify the returned object
    freely without the change leaking into other tests.
    """
    return copy.deepcopy(_get_config_module(fname).model)


@pytest.mark.parametrize(
    'cfg_file', ['configs/ner/bert_softmax/bert_softmax_cluener_18e.py'])
def test_bert_softmax(cfg_file):
    """Smoke-test the detector built from ``cfg_file``.

    The detector is built twice: once with the loss configured in the file
    and once with the loss type switched to ``MaskedFocalLoss``. After each
    build, ``forward`` is called with the dummy inputs and must return a
    dict. Finally ``forward`` is called with ``return_loss=False`` under
    ``torch.no_grad()``; that call is only checked for not raising.

    Args:
        cfg_file (str): Path of the config file to load the model from.
    """
    # prepare data
    texts = ['中'] * 47
    img = [31] * 47
    labels = [31] * 128
    input_ids = [0] * 128
    attention_mask = [0] * 128
    token_type_ids = [0] * 128
    img_metas = {
        'texts': texts,
        'labels': torch.tensor(labels).unsqueeze(0),
        'img': img,
        'input_ids': torch.tensor(input_ids).unsqueeze(0),
        'attention_masks': torch.tensor(attention_mask).unsqueeze(0),
        'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0)
    }

    # create dummy data
    # The context manager removes the temporary directory even when a build,
    # a forward call or an assertion below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        vocab_file = osp.join(tmp_dir, 'fake_vocab.txt')
        _create_dummy_vocab_file(vocab_file)

        model = _get_detector_cfg(cfg_file)
        model['label_convertor']['vocab_file'] = vocab_file

        detector = build_detector(model)
        losses = detector.forward(img, img_metas)
        assert isinstance(losses, dict)

        # Rebuild with a different loss type on the same (copied) config.
        model['loss']['type'] = 'MaskedFocalLoss'
        detector = build_detector(model)
        losses = detector.forward(img, img_metas)
        assert isinstance(losses, dict)

    # Test forward test
    # As before, this runs after the dummy vocab file has been removed and
    # uses the detector built last (the MaskedFocalLoss one).
    with torch.no_grad():
        detector.forward(None, img_metas, return_loss=False)