# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean


@DATASETS.register_module()
class IcdarDataset(CocoDataset):
    """Dataset for text detection whose ``ann_file`` is in COCO format.

    Args:
        ann_file_backend (str): Storage backend for the annotation file,
            should be one of ['disk', 'petrel', 'http']. Defaults to 'disk'.
    """
    CLASSES = ('text', )

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        # Select only the first k images for fast debugging.
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
                         seg_prefix, proposal_file, test_mode, filter_empty_gt)
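
    # A minimal sketch of how this dataset is typically declared in an
    # OpenMMLab-style config. The pipeline below is illustrative only; the
    # transform names and the annotation/image paths are assumptions, not
    # values required by this class.
    #
    #     train = dict(
    #         type='IcdarDataset',
    #         ann_file='data/icdar2015/instances_training.json',  # hypothetical path
    #         img_prefix='data/icdar2015/imgs',                   # hypothetical path
    #         pipeline=[
    #             dict(type='LoadImageFromFile'),
    #             dict(type='LoadTextAnnotations', with_bbox=True, with_mask=True),
    #         ])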

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        count = 0
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            count = count + 1
            if count > self.select_first_k and self.select_first_k > 0:
                break
        return data_infos
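
    # For reference, each entry returned by load_annotations() is the COCO
    # image record plus a 'filename' alias; roughly of this shape (the field
    # values below are invented for illustration):
    #
    #     {'id': 1, 'file_name': 'img_1.jpg', 'filename': 'img_1.jpg',
    #      'height': 720, 'width': 1280}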

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotations.

        Args:
            img_info (dict): Image info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, masks_ignore, seg_map. "masks" and
                "masks_ignore" are represented by polygon boundary
                point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)
        return ann
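
    # As a rough illustration (values invented), the dict returned by
    # _parse_ann_info() for an image with one valid text instance and one
    # crowd region would look like:
    #
    #     dict(
    #         bboxes=np.array([[10., 20., 110., 60.]], dtype=np.float32),
    #         labels=np.array([0], dtype=np.int64),
    #         bboxes_ignore=np.array([[200., 30., 260., 70.]], dtype=np.float32),
    #         masks=[[[10, 20, 110, 20, 110, 60, 10, 60]]],
    #         masks_ignore=[[[200, 30, 260, 30, 260, 70, 200, 70]]],
    #         seg_map='img_1.png')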

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=0.3,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Score threshold for filtering predictions.
                Default: 0.3.
            rank_list (str): Json file used to save the eval result
                of each image after ranking.

        Returns:
            dict[dict[str: float]]: The evaluation results.
        """
        assert utils.is_type_list(results, dict)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_info = {'filename': self.data_infos[i]['file_name']}
            img_infos.append(img_info)
            ann_infos.append(self.get_ann_info(i))

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            logger=logger,
            rank_list=rank_list)

        return eval_results
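

# A minimal usage sketch, assuming a COCO-format annotation file and image
# folder exist at the hypothetical paths below, and that 'LoadImageFromFile'
# is registered by the installed mmdet version. Calling evaluate() would
# additionally require per-image detection results (in mmocr's text-detection
# output format), which are not fabricated here.
if __name__ == '__main__':
    dataset = IcdarDataset(
        ann_file='data/icdar2015/instances_training.json',  # hypothetical path
        img_prefix='data/icdar2015/imgs',  # hypothetical path
        pipeline=[dict(type='LoadImageFromFile')],
        select_first_k=10)  # load only the first images for a quick check
    print(f'{len(dataset)} images loaded')
    print(dataset.get_ann_info(0))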