Spaces:
Build error
Build error
File size: 2,269 Bytes
5a444be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from typing import Dict, List, Optional, Tuple
import torch
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances, Boxes
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    """Detectron2 meta-architecture built on ``GeneralizedRCNN``.

    Differences from the base class visible here:

    * a proposal generator is mandatory (checked in ``__init__``);
    * training batches must be single-task: every input dict carries a
      ``"task"`` entry, all entries must agree, and that value is forwarded
      to ``roi_heads`` as ``targets_task``;
    * inference always uses generated proposals; ``detected_instances`` is
      not supported.
    """

    @configurable
    def __init__(self, **kwargs):
        """Build the model from the keyword arguments produced by ``from_config``.

        Raises:
            AssertionError: if no proposal generator was configured.
        """
        super().__init__(**kwargs)
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        """Translate ``cfg`` into ``__init__`` keyword arguments.

        Currently delegates unchanged to ``GeneralizedRCNN.from_config``.
        """
        ret = super().from_config(cfg)
        return ret

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """Run the model on a batch in eval mode.

        Args:
            batched_inputs: per-image input dicts, in the format accepted by
                ``preprocess_image``.
            detected_instances: must be ``None``; precomputed boxes are not
                supported by this architecture.
            do_postprocess: if True, pass the results through
                ``GeneralizedRCNN._postprocess`` (using each input dict and the
                preprocessed image sizes). Otherwise return the raw
                ``roi_heads`` output.

        Returns:
            The postprocessed results when ``do_postprocess`` is True, else the
            results returned by ``roi_heads``.

        Raises:
            AssertionError: if called in training mode or with
                ``detected_instances``.
        """
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        # No ground truth at inference time, hence ``None`` targets.
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(features, proposals)

        if not do_postprocess:
            return results
        assert not torch.jit.is_scripting(), \
            "Scripting is not supported for postprocess."
        return GRiT._postprocess(results, batched_inputs, images.image_sizes)

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """Compute training losses, or run inference when not training.

        Args:
            batched_inputs: per-image input dicts. In training mode each dict
                must provide ``"instances"`` (ground truth) and ``"task"``.

        Returns:
            In eval mode, whatever ``inference`` returns. In training mode, a
            dict that merges the ``roi_heads`` losses and the proposal-generator
            losses. On a key collision the proposal-generator value wins.

        Raises:
            ValueError: if the inputs of a training batch do not all share the
                same ``"task"``.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        # roi_heads is given a single task for the whole batch, so mixed-task
        # batches must be rejected. This is an explicit raise rather than an
        # assert so the check still runs under ``python -O``.
        targets_task = batched_inputs[0]["task"]
        for anno_per_image in batched_inputs:
            if anno_per_image["task"] != targets_task:
                raise ValueError(
                    "All inputs in a training batch must share one task; "
                    f"got {targets_task!r} and {anno_per_image['task']!r}."
                )

        features = self.backbone(images.tensor)
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
        # Only the loss dict is needed in training; the returned detections are
        # discarded. (The name suggests these include text-decoder losses —
        # confirm against the roi_heads implementation.)
        _, roihead_textdecoder_losses = self.roi_heads(
            features, proposals, gt_instances, targets_task=targets_task)

        # Same precedence as the original update() sequence: proposal losses
        # override roi_heads losses on duplicate keys.
        return {**roihead_textdecoder_losses, **proposal_losses}