# Hugging Face Spaces status: Running on T4
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union | |
import torch | |
from mmdet.models.data_preprocessors import DetDataPreprocessor | |
from mmengine.structures import BaseDataElement | |
from mmyolo.registry import MODELS | |
# Union of the value types a batch may contain: tensors, mmengine data
# elements, plain containers, raw bytes/strings, or ``None``. Presumably kept
# in step with the ``cast_data`` helper used by data preprocessors.
# NOTE(review): not referenced elsewhere in this file; confirm it is still needed.
CastData = Optional[Union[
    tuple,
    dict,
    BaseDataElement,
    torch.Tensor,
    list,
    bytes,
    str,
]]
# Register with the MODELS registry so configs can build this class via
# ``type='YOLOWDetDataPreprocessor'`` (``MODELS`` is imported for this purpose).
@MODELS.register_module()
class YOLOWDetDataPreprocessor(DetDataPreprocessor):
    """Data preprocessor for pre-collated YOLO-World training batches.

    In training mode the batch is expected to be already collated: ``inputs``
    is a single stacked tensor and ``data_samples`` is one dict of batched
    annotations, not a list of per-image data samples. This skips per-sample
    processing for faster training. In evaluation mode the call is delegated
    unchanged to :class:`DetDataPreprocessor`.

    Note: It must be used together with `mmyolo.datasets.utils.yolow_collate`

    Args:
        *args: Positional arguments forwarded to ``DetDataPreprocessor``.
        non_blocking (bool, optional): Forwarded to ``DetDataPreprocessor``.
            Defaults to ``True`` here. Presumably controls asynchronous
            host-to-device copies in ``cast_data``; confirm against mmengine.
        **kwargs: Keyword arguments forwarded to ``DetDataPreprocessor``.
    """

    def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
        super().__init__(*args, non_blocking=non_blocking, **kwargs)

    def forward(self, data: dict, training: bool = False) -> dict:
        """Preprocess a batch, using a fast path at training time.

        Evaluation (``training=False``) delegates to
        ``DetDataPreprocessor.forward``. Training performs, in order: device
        casting (``cast_data``), optional BGR->RGB channel swap, optional
        normalization with ``self.mean`` / ``self.std``, and any configured
        batch augments. No padding is done on this path; ``inputs`` must
        already be a stacked tensor.

        Args:
            data (dict): Data sampled from dataloader. In training mode it
                must hold ``'inputs'`` (a stacked tensor) and
                ``'data_samples'`` (a dict with at least ``'bboxes_labels'``
                and ``'texts'``).
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: ``{'inputs': Tensor, 'data_samples': dict}``. In training
            mode ``data_samples`` holds ``'bboxes_labels'``, ``'texts'``,
            ``'img_metas'`` (one dict per image with ``'batch_input_shape'``)
            and, when present in the input, ``'masks'`` and ``'is_detection'``.
        """
        if not training:
            return super().forward(data, training)

        data = self.cast_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        # The fast path relies on the batched dict layout produced by
        # ``yolow_collate``; a list of per-image samples cannot be handled.
        assert isinstance(data_samples, dict), (
            'YOLOWDetDataPreprocessor expects `data_samples` to be a dict in '
            'training mode (use it together with `yolow_collate`), got '
            f'{type(data_samples)}')

        # TODO: Supports multi-scale training
        if self._channel_conversion and inputs.shape[1] == 3:
            # Reverse the channel order along dim 1 (bgr2rgb conversion).
            inputs = inputs[:, [2, 1, 0], ...]
        if self._enable_normalize:
            inputs = (inputs - self.mean) / self.std

        if self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        # Build one independent meta dict per image. ``[meta] * n`` would make
        # every entry alias the same dict, so an in-place update of one
        # image's meta by downstream code would silently change all of them.
        batch_input_shape = inputs.shape[2:]
        img_metas = [{
            'batch_input_shape': batch_input_shape
        } for _ in range(len(inputs))]

        data_samples_output = {
            'bboxes_labels': data_samples['bboxes_labels'],
            'texts': data_samples['texts'],
            'img_metas': img_metas,
        }
        # Optional annotations are forwarded only when the collate fn
        # provided them.
        for key in ('masks', 'is_detection'):
            if key in data_samples:
                data_samples_output[key] = data_samples[key]
        return {'inputs': inputs, 'data_samples': data_samples_output}