Spaces:
Runtime error
Runtime error
File size: 1,665 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ner_metric import eval_ner_f1
from mmocr.datasets.base_dataset import BaseDataset
@DATASETS.register_module()
class NerDataset(BaseDataset):
    """Custom dataset for named entity recognition tasks.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        pipeline (list[dict]): Processing pipeline.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
    """

    def prepare_train_img(self, index):
        """Get training data and annotations after pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        # The raw annotation record at ``index`` is handed directly to the
        # pipeline; this method does no other preprocessing itself.
        ann_info = self.data_infos[index]
        return self.pipeline(ann_info)

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset with the NER F1 metric.

        Args:
            results (list): Testing results of the dataset, passed unchanged
                to ``eval_ner_f1`` together with the ground truth.
            metric (str | list[str], optional): Accepted for compatibility
                with the common ``evaluate`` signature. It is currently
                ignored: ``eval_ner_f1`` is always used. Default: None.
            logger (logging.Logger | str | None, optional): Accepted for
                compatibility; currently unused. Default: None.
            **kwargs: Accepted for compatibility; currently ignored.

        Returns:
            dict: Whatever ``eval_ner_f1`` returns. It is expected to contain
                the keys 'acc', 'recall' and 'f1-score' (confirm against
                ``eval_ner_f1``).
        """
        # Materialize the ground-truth annotations into a plain list before
        # handing them to the metric. Presumably this is needed because
        # ``data_infos`` is a loader object rather than a list; verify
        # against the loader built from ``loader``.
        gt_infos = list(self.data_infos)
        eval_results = eval_ner_f1(results, gt_infos)
        return eval_results
|