File size: 5,784 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
# 教程 3: 自定义数据预处理流程
## 数据流程的设计
按照惯例,我们使用 `Dataset` 和 `DataLoader` 进行多进程的数据加载。`Dataset` 返回字典类型的数据,数据内容为模型 `forward` 方法的各个参数。由于在目标检测中,输入的图像数据具有不同的大小,我们在 `MMCV` 里引入一个新的 `DataContainer` 类去收集和分发不同大小的输入数据。更多细节请参考[这里](https://github.com/open-mmlab/mmcv/blob/master/mmcv/parallel/data_container.py)。
数据的准备流程和数据集是解耦的。通常一个数据集定义了如何处理标注数据(annotations)信息,而一个数据流程定义了准备一个数据字典的所有步骤。一个流程包括一系列的操作,每个操作都把一个字典作为输入,然后再输出一个新的字典给下一个变换操作。
我们在下图展示了一个经典的数据处理流程。蓝色块是数据处理操作,随着数据流程的处理,每个操作都可以在结果字典中加入新的键(标记为绿色)或更新现有的键(标记为橙色)。
![pipeline figure](../../../resources/data_pipeline.png)
这些操作可以分为数据加载(data loading)、预处理(pre-processing)、格式变化(formatting)和测试时数据增强(test-time augmentation)。
下面的例子是 `Faster R-CNN` 的一个流程:
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
```
对于每个操作,我们列出它添加、更新、移除的相关字典域 (dict fields):
### 数据加载 Data loading
`LoadImageFromFile`
- 增加:img, img_shape, ori_shape
`LoadAnnotations`
- 增加:gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg, bbox_fields, mask_fields
`LoadProposals`
- 增加:proposals
### 预处理 Pre-processing
`Resize`
- 增加:scale, scale_idx, pad_shape, scale_factor, keep_ratio
- 更新:img, img_shape, \*bbox_fields, \*mask_fields, \*seg_fields
`RandomFlip`
- 增加:flip
- 更新:img, \*bbox_fields, \*mask_fields, \*seg_fields
`Pad`
- 增加:pad_fixed_size, pad_size_divisor
- 更新:img, pad_shape, \*mask_fields, \*seg_fields
`RandomCrop`
- 更新:img, pad_shape, gt_bboxes, gt_labels, gt_masks, \*bbox_fields
`Normalize`
- 增加:img_norm_cfg
- 更新:img
`SegRescale`
- 更新:gt_semantic_seg
`PhotoMetricDistortion`
- 更新:img
`Expand`
- 更新:img, gt_bboxes
`MinIoURandomCrop`
- 更新:img, gt_bboxes, gt_labels
`Corrupt`
- 更新:img
### 格式 Formatting
`ToTensor`
- 更新:由 `keys` 指定
`ImageToTensor`
- 更新:由 `keys` 指定
`Transpose`
- 更新:由 `keys` 指定
`ToDataContainer`
- 更新:由 `keys` 指定
`DefaultFormatBundle`
- 更新:img, proposals, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg
`Collect`
- 增加:img_metas(img_metas 的键(key)被 `meta_keys` 指定)
- 移除:除了 `keys` 指定的键(key)之外的所有其他的键(key)
### 测试时数据增强 Test time augmentation
`MultiScaleFlipAug`
## 拓展和使用自定义的流程
1. 在任意文件里写一个新的流程,例如在 `my_pipeline.py`,它以一个字典作为输入并且输出一个字典:
```python
import random
from mmdet.datasets import PIPELINES
@PIPELINES.register_module()
class MyTransform:
"""Add your transform
Args:
p (float): Probability of shifts. Default 0.5.
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, results):
if random.random() > self.p:
results['dummy'] = True
return results
```
2. 在配置文件里调用并使用你写的数据处理流程,需要确保你的训练脚本能够正确导入新增模块:
```python
custom_imports = dict(imports=['path.to.my_pipeline'], allow_failed_imports=False)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='MyTransform', p=0.2),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
```
3. 可视化数据增强处理流程的结果
如果想要可视化数据增强处理流程的结果,可以使用 `tools/misc/browse_dataset.py` 直观
地浏览检测数据集(图像和标注信息),或将图像保存到指定目录。
使用方法请参考[日志分析](../useful_tools.md)
|