# Uploaded by camenduru: "thanks to show ❤" (3bbb319)
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from mmdet.core import voc_classes
# Lookup table from each VOC class name to its 0-based label index, in the
# order returned by ``voc_classes()``.
label_ids = {cls_name: idx for idx, cls_name in enumerate(voc_classes())}
def _voc_boxes_to_arrays(bboxes, labels):
    """Turn per-object box/label lists into float32 (N, 4) / int64 (N,) arrays.

    Coordinates are shifted by -1 (presumably because VOC pixel indices are
    1-based). Empty input yields arrays of shape (0, 4) and (0,).
    """
    if not bboxes:
        return (np.zeros((0, 4), dtype=np.float32),
                np.zeros((0, ), dtype=np.int64))
    return ((np.array(bboxes, ndmin=2) - 1).astype(np.float32),
            np.array(labels).astype(np.int64))


def parse_xml(args):
    """Parse one PASCAL VOC XML annotation file into an mmdetection record.

    Args:
        args (tuple[str, str]): ``(xml_path, img_path)``. Passed as a single
            tuple so the function can be mapped with ``mmcv.track_progress``.
            ``img_path`` is stored verbatim as the record's ``filename``.

    Returns:
        dict: ``filename``, ``width``, ``height`` and an ``ann`` dict with
        ``bboxes`` (float32, (N, 4), x1/y1/x2/y2) and ``labels`` (int64,
        (N,)), plus ``bboxes_ignore`` / ``labels_ignore`` holding the
        objects flagged as ``difficult``.

    Raises:
        KeyError: If an object's class name is not a key of ``label_ids``.
    """
    xml_path, img_path = args
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bboxes, labels = [], []
    bboxes_ignore, labels_ignore = [], []
    for obj in root.findall('object'):
        label = label_ids[obj.find('name').text]
        # Some annotation files omit <difficult>; treat a missing tag as 0
        # instead of crashing on ``None.text``.
        difficult_node = obj.find('difficult')
        difficult = 0 if difficult_node is None else int(difficult_node.text)
        bnd_box = obj.find('bndbox')
        # Go through float() so fractional coordinate strings (e.g. "12.5")
        # are accepted; plain int() would raise ValueError on them.
        bbox = [
            int(float(bnd_box.find(key).text))
            for key in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        if difficult:
            bboxes_ignore.append(bbox)
            labels_ignore.append(label)
        else:
            bboxes.append(bbox)
            labels.append(label)

    bboxes, labels = _voc_boxes_to_arrays(bboxes, labels)
    bboxes_ignore, labels_ignore = _voc_boxes_to_arrays(
        bboxes_ignore, labels_ignore)
    annotation = {
        'filename': img_path,
        'width': w,
        'height': h,
        'ann': {
            'bboxes': bboxes,
            'labels': labels,
            'bboxes_ignore': bboxes_ignore,
            'labels_ignore': labels_ignore
        }
    }
    return annotation
def cvt_annotations(devkit_path, years, split, out_file):
    """Convert one VOC image set (optionally merged over years) and dump it.

    Args:
        devkit_path (str): Devkit root containing ``VOC<year>`` folders.
        years (str | list[str]): A single year such as ``'2007'``, or a list
            of years whose records are concatenated in order.
        split (str): Image-set name, e.g. ``'train'`` or ``'test'``.
        out_file (str): Destination path. A name ending in ``json`` is
            converted with ``cvt_to_coco_json`` first; otherwise the list of
            ``parse_xml`` records is dumped as-is.

    Returns:
        The object that was dumped, or ``None`` when any requested year is
        missing the split's file list (in that case nothing is written).
    """
    year_list = years if isinstance(years, list) else [years]
    collected = []
    for year in year_list:
        filelist = osp.join(devkit_path,
                            f'VOC{year}/ImageSets/Main/{split}.txt')
        if not osp.isfile(filelist):
            print(f'filelist does not exist: {filelist}, '
                  f'skip voc{year} {split}')
            return None
        jobs = []
        for img_name in mmcv.list_from_file(filelist):
            xml_path = osp.join(devkit_path,
                                f'VOC{year}/Annotations/{img_name}.xml')
            jobs.append((xml_path, f'VOC{year}/JPEGImages/{img_name}.jpg'))
        collected += mmcv.track_progress(parse_xml, jobs)

    if out_file.endswith('json'):
        result = cvt_to_coco_json(collected)
    else:
        result = collected
    mmcv.dump(result, out_file)
    return result
def cvt_to_coco_json(annotations, classes=None):
    """Convert mmdetection VOC records into a COCO-style annotation dict.

    Args:
        annotations (list[dict]): Records shaped like ``parse_xml`` output:
            ``filename``, ``width``, ``height`` and an ``ann`` dict holding
            ``bboxes``/``labels`` and ``bboxes_ignore``/``labels_ignore``.
            Boxes are 2-D arrays whose first four columns are x1, y1, x2, y2.
        classes (Sequence[str], optional): Category names indexed by label.
            Defaults to ``voc_classes()``.

    Returns:
        dict: With keys ``images``, ``type``, ``categories`` and
        ``annotations``. Image and annotation ids are assigned sequentially
        from 0. Boxes from the ``*_ignore`` arrays get ``iscrowd=1``.

    Raises:
        ValueError: If two records share a ``filename``, or an image's box
            and label arrays have different lengths.
    """
    if classes is None:
        classes = voc_classes()

    coco = dict()
    coco['images'] = []
    coco['type'] = 'instance'
    coco['categories'] = []
    coco['annotations'] = []

    def add_ann_item(image_id, category_id, bbox, difficult_flag):
        """Append one x1, y1, x2, y2 box to ``coco['annotations']``."""
        x1, y1, x2, y2 = bbox
        # Polygon through the corners: left-top, left-bottom,
        # right-bottom, right-top.
        seg = [
            int(x1), int(y1), int(x1), int(y2),
            int(x2), int(y2), int(x2), int(y1)
        ]
        xywh = np.array([x1, y1, x2 - x1, y2 - y1])
        annotation_item = dict()
        annotation_item['segmentation'] = [seg]
        annotation_item['area'] = int(xywh[2] * xywh[3])
        # Difficult boxes are marked as crowd regions ('ignore' stays 0);
        # presumably downstream COCO loaders treat iscrowd boxes as ignored.
        annotation_item['ignore'] = 0
        annotation_item['iscrowd'] = 1 if difficult_flag == 1 else 0
        annotation_item['image_id'] = int(image_id)
        annotation_item['bbox'] = xywh.astype(int).tolist()
        annotation_item['category_id'] = int(category_id)
        # Ids are sequential from 0 because this is the only append site.
        annotation_item['id'] = len(coco['annotations'])
        coco['annotations'].append(annotation_item)

    def add_boxes(image_id, bboxes, labels, difficult_flag):
        """Add every box of one image; boxes and labels must pair up."""
        if len(bboxes) != len(labels):
            raise ValueError(f'image {image_id}: {len(bboxes)} boxes but '
                             f'{len(labels)} labels')
        for bbox, label in zip(bboxes[:, :4], labels):
            add_ann_item(image_id, label, bbox, difficult_flag)

    for category_id, name in enumerate(classes):
        coco['categories'].append({
            'supercategory': 'none',
            'id': int(category_id),
            'name': str(name),
        })

    image_set = set()
    for image_id, ann_dict in enumerate(annotations):
        file_name = ann_dict['filename']
        # Explicit raise (not assert) so the check survives ``python -O``.
        if file_name in image_set:
            raise ValueError(f'duplicate image filename: {file_name}')
        image_set.add(file_name)
        coco['images'].append({
            'id': int(image_id),
            'file_name': str(file_name),
            'height': int(ann_dict['height']),
            'width': int(ann_dict['width']),
        })
        ann = ann_dict['ann']
        add_boxes(image_id, ann['bboxes'], ann['labels'], difficult_flag=0)
        add_boxes(image_id, ann['bboxes_ignore'], ann['labels_ignore'],
                  difficult_flag=1)
    return coco
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``devkit_path``, ``out_dir`` (``None`` when the
        option is not given) and ``out_format`` (``'pkl'`` or ``'coco'``,
        defaulting to ``'pkl'``).
    """
    cli = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmdetection format')
    cli.add_argument('devkit_path', help='pascal voc devkit path')
    cli.add_argument('-o', '--out-dir', help='output path')
    cli.add_argument(
        '--out-format',
        choices=('pkl', 'coco'),
        default='pkl',
        help='output format, "coco" indicates coco annotation format')
    return cli.parse_args()
def main():
    """Convert every PASCAL VOC split found under the devkit path.

    Looks for ``VOC2007`` and ``VOC2012`` under ``devkit_path``. When both
    exist, a merged ``voc0712`` set is converted as well. Each set gets its
    ``train``, ``val`` and ``trainval`` splits converted; single-year sets
    additionally get ``test``. Files are written to ``--out-dir`` (default:
    the devkit path) as ``<prefix>_<split>`` with a ``.pkl`` extension, or
    ``.json`` for ``--out-format coco``.

    Raises:
        IOError: If neither ``VOC2007`` nor ``VOC2012`` exists.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    out_dir = args.out_dir or devkit_path
    mmcv.mkdir_or_exist(out_dir)

    # Single years are plain strings; the merged 07+12 set is a list, which
    # cvt_annotations treats as multi-year input.
    years = [
        year for year in ('2007', '2012')
        if osp.isdir(osp.join(devkit_path, f'VOC{year}'))
    ]
    if len(years) == 2:
        years.append(['2007', '2012'])
    if not years:
        raise IOError(f'The devkit path {devkit_path} contains neither '
                      '"VOC2007" nor "VOC2012" subfolder')

    out_fmt = '.json' if args.out_format == 'coco' else f'.{args.out_format}'
    single_year_prefix = {'2007': 'voc07', '2012': 'voc12'}

    for year in years:
        is_merged = isinstance(year, list)
        prefix = 'voc0712' if is_merged else single_year_prefix[year]
        splits = ['train', 'val', 'trainval']
        if not is_merged:
            # No merged test set is produced.
            splits.append('test')
        for split in splits:
            dataset_name = f'{prefix}_{split}'
            print(f'processing {dataset_name} ...')
            cvt_annotations(devkit_path, year, split,
                            osp.join(out_dir, dataset_name + out_fmt))
    print('Done!')


if __name__ == '__main__':
    main()