# Spaces: Running on T4
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union | |
from torch import Tensor | |
from mmdet.structures import OptSampleList, SampleList | |
from mmyolo.models.detectors import YOLODetector | |
from mmyolo.registry import MODELS | |
# NOTE(review): `MODELS` is imported in this file but no
# `@MODELS.register_module()` decorator is visible on this class. Confirm
# the detector is registered elsewhere before relying on config-based
# builds.
class YOLOWorldDetector(YOLODetector):
    """Text-conditioned detector of the YOLO-World ("YOLOW") series.

    The backbone is called with both the image batch and a collection of
    text prompts and returns image features together with text features.
    The text features are forwarded to the bbox head, and optionally to
    the neck as well.

    Args:
        *args: Positional arguments forwarded to ``YOLODetector``.
        mm_neck (bool): If True, the neck is called as
            ``neck(img_feats, txt_feats)``; otherwise as
            ``neck(img_feats)``. Defaults to False.
        num_train_classes (int): Value written to
            ``bbox_head.num_classes`` at the start of :meth:`loss`.
            Defaults to 80.
        num_test_classes (int): Stored on the instance. It is not read
            anywhere in this class; :meth:`predict` derives the class
            count from the text features instead. Defaults to 80.
        **kwargs: Keyword arguments forwarded to ``YOLODetector``.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes: int = 80,
                 num_test_classes: int = 80,
                 **kwargs) -> None:
        # These plain attributes are deliberately assigned *before* the
        # base-class constructor runs, exactly as in the original code.
        # NOTE(review): presumably so they already exist while the base
        # class builds its sub-modules -- confirm before reordering.
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        super().__init__(*args, **kwargs)

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images passed to the backbone.
            batch_data_samples (SampleList): Data samples that carry the
                text prompts and are also handed to the head's loss.

        Returns:
            Union[dict, list]: Whatever ``bbox_head.loss`` returns.
        """
        # The head is shared between training and inference, so its class
        # count is reset to the training value on every loss call.
        self.bbox_head.num_classes = self.num_train_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        return self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing.

        Args:
            batch_inputs (Tensor): Input images passed to the backbone.
            batch_data_samples (SampleList): Data samples that carry the
                text prompts; predictions are attached to them.
            rescale (bool): Forwarded unchanged to ``bbox_head.predict``.
                Defaults to True.

        Returns:
            SampleList: ``batch_data_samples`` after
            ``add_pred_to_datasample`` has attached the head's results.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        # The class count follows the prompts actually supplied rather
        # than the fixed ``num_test_classes``: it is the size of the first
        # dimension of the first text-feature entry.
        # NOTE(review): assumes that dimension indexes the text prompts
        # (one row per prompt) -- confirm against the backbone output.
        self.bbox_head.num_classes = txt_feats[0].shape[0]
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)
        return self.add_pred_to_datasample(batch_data_samples, results_list)

    def reparameterize(self, texts: List[List[str]]) -> None:
        """Store the text prompts used when no data samples are given.

        Args:
            texts (List[List[str]]): Prompts later read by
                :meth:`extract_feat` when ``batch_data_samples`` is None.
        """
        self.texts = texts

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None
    ) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and
        head forward without any post-processing.

        Args:
            batch_inputs (Tensor): Input images passed to the backbone.
            batch_data_samples (OptSampleList): Optional data samples
                carrying the text prompts. When None, the prompts stored
                by :meth:`reparameterize` are used. Defaults to None.

        Returns:
            Tuple[List[Tensor]]: Whatever ``bbox_head.forward`` returns.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        return self.bbox_head.forward(img_feats, txt_feats)

    def extract_feat(
        self, batch_inputs: Tensor,
        batch_data_samples: Union[OptSampleList, dict]
    ) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image and text features.

        The text prompts are resolved from ``batch_data_samples``:

        - ``None``: the prompts stored by :meth:`reparameterize`;
        - ``dict``: the value under the ``'texts'`` key;
        - ``list``: the ``texts`` attribute of every element.

        Args:
            batch_inputs (Tensor): Input images passed to the backbone.
            batch_data_samples (list | dict | None): Source of the text
                prompts, see above.

        Returns:
            tuple: ``(img_feats, txt_feats)`` as returned by the backbone,
            with ``img_feats`` additionally passed through the neck when
            the detector has one.

        Raises:
            AttributeError: If ``batch_data_samples`` is None and no
                prompts have been stored via :meth:`reparameterize`.
            TypeError: If ``batch_data_samples`` is neither None, a dict
                nor a list.
        """
        if batch_data_samples is None:
            try:
                texts = self.texts
            except AttributeError as err:
                # Same exception type as the bare attribute lookup would
                # raise, but with a message that tells the caller how to
                # fix the problem.
                raise AttributeError(
                    f'{type(self).__name__} has no stored texts; call '
                    'reparameterize(texts) before running without '
                    'batch_data_samples.') from err
        elif isinstance(batch_data_samples, dict):
            texts = batch_data_samples['texts']
        elif isinstance(batch_data_samples, list):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        # The backbone consumes images and prompts together.
        img_feats, txt_feats = self.backbone(batch_inputs, texts)
        if self.with_neck:
            if self.mm_neck:
                # The neck takes the text features as a second input.
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats