# Copyright/paste artifacts removed: the leading "Spaces:" / "Runtime error"
# lines were residue from an external paste and are not part of this module.
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
import tempfile
import torch
from mmocr.datasets.ner_dataset import NerDataset
from mmocr.models.ner.convertors.ner_convertor import NerConvertor
from mmocr.utils import list_to_file
def _create_dummy_ann_file(ann_file):
    """Write a one-line NER annotation file for the tests.

    The single JSON record carries a Chinese sentence plus two labeled
    entity spans ('address' and 'name'), matching the format the
    LineJsonParser expects (keys: 'text', 'label').

    Args:
        ann_file (str): Destination path for the annotation file.
    """
    label = {
        'address': {'台湾': [[15, 16]]},
        'name': {'彭小军': [[0, 2]]},
    }
    record = {
        'text': '彭小军认为,国内银行现在走的是台湾的发卡模式',
        'label': label,
    }
    # ensure_ascii=False keeps the CJK characters readable in the file.
    list_to_file(ann_file, [json.dumps(record, ensure_ascii=False)])
def _create_dummy_vocab_file(vocab_file):
    """Write a dummy vocabulary file containing the letters a-z, one per line.

    Fixes two defects in the original:
      * ``list_to_file`` was called once per character inside the loop;
        each call rewrites the file, so only the final character ('z')
        survived.  All characters are now written in a single call.
        (Assumes ``list_to_file`` opens the file in write mode and emits
        one line per list item — TODO confirm against mmocr.utils.)
      * Each character was wrapped in ``json.dumps(char + '\\n')``, which
        embedded literal quotes and an escaped newline into every vocab
        entry instead of the raw character.

    Args:
        vocab_file (str): Destination path for the vocabulary file.
    """
    chars = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    list_to_file(vocab_file, chars)
def _create_dummy_loader(): | |
loader = dict( | |
type='HardDiskLoader', | |
repeat=1, | |
parser=dict(type='LineJsonParser', keys=['text', 'label'])) | |
return loader | |
def test_ner_dataset():
    """Smoke-test NerDataset: construction, pipeline hooks, evaluation,
    and the convertor's prediction-to-entity decoding."""
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]

    # Materialize dummy annotation and vocabulary files in a temp dir.
    tmp_dir = tempfile.TemporaryDirectory()
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    vocab_file = osp.join(tmp_dir.name, 'fake_vocab.txt')
    _create_dummy_ann_file(ann_file)
    _create_dummy_vocab_file(vocab_file)

    max_len = 128
    convertor_cfg = {
        'type': 'NerConvertor',
        'annotation_type': 'bio',
        'vocab_file': vocab_file,
        'categories': categories,
        'max_len': max_len,
    }
    pipeline = [
        {
            'type': 'NerTransform',
            'label_convertor': convertor_cfg,
            'max_len': max_len,
        },
        {'type': 'ToTensorNER'},
    ]
    dataset = NerDataset(ann_file, loader, pipeline=pipeline)

    # pre_pipeline should accept a results dict built from data_infos.
    results = {'img_info': dataset.data_infos[0]}
    dataset.pre_pipeline(results)

    # prepare_train_img should run the full pipeline on sample 0.
    dataset.prepare_train_img(0)

    # Evaluation with the ground-truth entity spans.
    dataset.evaluate([[['address', 15, 16], ['name', 0, 2]]])

    # Decode a raw per-token prediction sequence back into entities.
    pred = [
        21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1, 11,
        21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
        11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21,
        21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21,
        21, 21
    ]
    preds = [pred[:max_len]]

    # Only the first 10 positions are valid (mask == 1).
    mask = [1] * 10 + [0] * (max_len - 10)
    assert len(preds[0]) == len(mask)
    masks = torch.tensor([mask])

    convertor = NerConvertor(
        annotation_type='bio',
        vocab_file=vocab_file,
        categories=categories,
        max_len=max_len)
    all_entities = convertor.convert_pred2entities(preds=preds, masks=masks)
    assert len(all_entities[0][0]) == 3

    tmp_dir.cleanup()