from typing import Dict, List, Optional, Tuple

import torch

from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN


@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    """GRiT meta-architecture.

    A GeneralizedRCNN variant whose ROI heads additionally run a text
    decoder. Training is task-conditioned: every image in a batch must
    carry the same ``'task'`` annotation, which is forwarded to the ROI
    heads as ``targets_task``.
    """

    @configurable
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # GRiT is proposal-driven; a proposal generator is mandatory.
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        # No GRiT-specific config keys; defer entirely to GeneralizedRCNN.
        return super().from_config(cfg)

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """Run detection on a batch (eval mode only).

        ``detected_instances`` is accepted for signature compatibility with
        GeneralizedRCNN but is not supported and must stay ``None``. When
        ``do_postprocess`` is true, results are rescaled to the original
        image sizes via the inherited ``_postprocess``.
        """
        assert not self.training
        assert detected_instances is None

        image_list = self.preprocess_image(batched_inputs)
        feature_maps = self.backbone(image_list.tensor)
        proposals, _ = self.proposal_generator(image_list, feature_maps, None)
        predictions, _ = self.roi_heads(feature_maps, proposals)

        if not do_postprocess:
            return predictions
        assert not torch.jit.is_scripting(), \
            "Scripting is not supported for postprocess."
        return GRiT._postprocess(
            predictions, batched_inputs, image_list.image_sizes)

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """Training forward pass returning a loss dict.

        In eval mode this simply delegates to :meth:`inference`.
        """
        if not self.training:
            return self.inference(batched_inputs)

        image_list = self.preprocess_image(batched_inputs)
        gt_instances = [inp["instances"].to(self.device)
                        for inp in batched_inputs]

        # All images in one batch must be annotated for the same task.
        targets_task = batched_inputs[0]['task']
        for entry in batched_inputs:
            assert targets_task == entry['task']

        feature_maps = self.backbone(image_list.tensor)
        proposals, proposal_losses = self.proposal_generator(
            image_list, feature_maps, gt_instances)
        proposals, roihead_textdecoder_losses = self.roi_heads(
            feature_maps, proposals, gt_instances, targets_task=targets_task)

        # Combine ROI-head/text-decoder losses with proposal losses.
        return {**roihead_textdecoder_losses, **proposal_losses}