|
import argparse
import copy
import json
import os
import shutil
import warnings
from collections import OrderedDict
from typing import List, Tuple

import cv2
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.nn.modules.conv import Conv2d
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision import models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.ops import MultiScaleRoIAlign
from tqdm import tqdm

import detection.transforms as transforms
import detection.utils as utils
from detection.engine import train_one_epoch, evaluate

def get_FRCNN_model(num_classes=1):
    """Builds a Faster R-CNN ResNet-50 FPN detector with a replaced box predictor.

    `num_classes` counts foreground classes only; one extra class is added for
    the background.
    """
    model = models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True,
        trainable_backbone_layers=3,
        min_size=1800,
        max_size=3600,
        image_std=(1.0, 1.0, 1.0),
        box_score_thresh=0.001,
    )

    # Replace the classification head so it predicts num_classes + 1 (background).
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)
    return model
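
# Illustrative usage sketch (not part of the pipeline): build the detector and run it
# on a dummy image. The dummy image size below is an assumption for demonstration only;
# pretrained weights are downloaded on first use.
#
#   model = get_FRCNN_model(num_classes=1)
#   model.eval()
#   with torch.no_grad():
#       prediction = model([torch.rand(3, 1800, 900)])[0]
#   # prediction holds "boxes", "labels" and "scores" tensors.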
|
|
|
|
|
|
|
class RoIpool(nn.Module):
    """Pools RoI features from two backbones and concatenates them channel-wise."""

    def __init__(self, pool):
        super().__init__()
        # One pooling module per backbone (deep copies of the original RoI pool).
        self.box_roi_pool1 = copy.deepcopy(pool)
        self.box_roi_pool2 = copy.deepcopy(pool)

    def forward(self, features, proposals, image_shapes):
        # `features` is a dict holding the feature maps of the two backbones.
        x = self.box_roi_pool1(features[0], proposals, image_shapes)
        y = self.box_roi_pool2(features[1], proposals, image_shapes)
        # Concatenate the pooled features along the channel dimension.
        z = torch.cat((x, y), dim=1)
        return z
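
# Shape sketch for RoIpool (illustrative, assuming the default torchvision FPN with
# 256-channel feature maps and MultiScaleRoIAlign output_size=7):
#   box_roi_pool1(features[0], ...) -> [num_proposals, 256, 7, 7]
#   box_roi_pool2(features[1], ...) -> [num_proposals, 256, 7, 7]
#   torch.cat((x, y), dim=1)        -> [num_proposals, 512, 7, 7]
# This is why the box head below is built with in_channels = 512 * 7 * 7.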
|
|
|
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels=None, representation_size=None):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x
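
# Illustrative shapes for TwoMLPHead as used here (assuming the concatenated RoI
# features above): an input of [num_proposals, 512, 7, 7] is flattened to
# [num_proposals, 25088] and mapped by fc6/fc7 to [num_proposals, 1024].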
|
|
|
|
|
|
|
class Bilateral_model(nn.Module):
    """Faster R-CNN variant with two frozen backbones whose RoI features are fused."""

    def __init__(self, frcnn_model):
        super().__init__()
        self.frcnn = frcnn_model
        self.transform = copy.deepcopy(frcnn_model.transform)
        self.backbone1 = copy.deepcopy(frcnn_model.backbone)
        self.backbone2 = copy.deepcopy(frcnn_model.backbone)
        self.rpn = copy.deepcopy(frcnn_model.rpn)
        self._has_warned = False

        # Freeze the RPN and both backbones; only the new RoI heads are trained.
        for param in self.rpn.parameters():
            param.requires_grad = False
        for param in self.backbone1.parameters():
            param.requires_grad = False
        for param in self.backbone2.parameters():
            param.requires_grad = False

        # RoI pooling over both backbones, followed by a box head sized for the
        # concatenated 512-channel features.
        box_roi_pool = RoIpool(frcnn_model.roi_heads.box_roi_pool)
        box_head = TwoMLPHead(512 * 7 * 7, 1024)
        box_predictor = copy.deepcopy(frcnn_model.roi_heads.box_predictor)

        box_score_thresh = 0.001
        box_nms_thresh = 0.5
        box_detections_per_img = 100
        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None

        self.roi_heads = RoIHeads(
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
|
|
|
    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        if self.training:
            return losses

        return detections
|
|
|
|
|
    def forward(self, images, targets=None):
        """
        Args:
            images (list[Tuple[Tensor, Tensor]]): pairs of images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] containing additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        if self.training:
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                else:
                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Record the original size of the first image in each pair so that detections
        # can be mapped back to the input resolution.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img[0].shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        # Split each pair into its two views and transform both with the same
        # (deep-copied) GeneralizedRCNN transform.
        images1 = [img[0] for img in images]
        images2 = [img[1] for img in images]
        targets2 = copy.deepcopy(targets)

        images1, targets = self.transform(images1, targets)
        images2, targets2 = self.transform(images2, targets2)

        # Reject degenerate boxes (non-positive height or width).
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}."
                    )

        # Run each view through its own (frozen) backbone.
        features1 = self.backbone1(images1.tensors)
        features2 = self.backbone2(images2.tensors)

        if isinstance(features1, torch.Tensor):
            features1 = OrderedDict([("0", features1)])
        if isinstance(features2, torch.Tensor):
            features2 = OrderedDict([("0", features2)])

        # Proposals come from the first view only; RoI features are pooled from both.
        proposals, proposal_losses = self.rpn(images1, features1, targets)
        features = {0: features1, 1: features2}
        detections, detector_losses = self.roi_heads(features, proposals, images1.image_sizes, targets)
        detections = self.transform.postprocess(detections, images1.image_sizes, original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
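

# Minimal usage sketch for Bilateral_model (illustrative only; the image sizes and the
# paired-input format are assumptions for demonstration). Each sample is a pair of image
# tensors: proposals are generated from the first image, while RoI features are pooled
# from both backbones and concatenated.
#
#   frcnn = get_FRCNN_model(num_classes=1)
#   model = Bilateral_model(frcnn)
#   model.eval()
#   with torch.no_grad():
#       pair = (torch.rand(3, 1800, 900), torch.rand(3, 1800, 900))
#       detections = model([pair])
#   # detections[0] contains "boxes", "labels" and "scores" for the first image of the pair.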
|
|