from typing import List, Tuple
from collections import OrderedDict
import warnings
import numpy as np
import pandas as pd
import cv2
import os
from torch.nn.modules.conv import Conv2d
from torch.utils.data.dataset import ConcatDataset
from tqdm import tqdm
import argparse
from torch.utils.data import Dataset,DataLoader
import torch
import torch.nn as nn
from torchvision import models
import detection.transforms as transforms
import torchvision.transforms as T
import detection.utils as utils
import torch.nn.functional as F
import shutil
import json
from detection.engine import train_one_epoch, evaluate
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.multiprocessing
import copy
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.roi_heads import RoIHeads
# First we create the baseline Faster R-CNN model (ResNet-50 FPN backbone)
def get_FRCNN_model(num_classes=1):
    model = models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True,
        trainable_backbone_layers=3,
        min_size=1800,
        max_size=3600,
        image_std=(1.0, 1.0, 1.0),
        box_score_thresh=0.001,
    )
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one (num_classes foreground classes + background)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes + 1)
    return model
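# A minimal usage sketch (illustrative only, kept as comments so the module has no
# side effects on import; the 1800x1800 input size is an assumption based on the
# min_size/max_size settings above):
#
#   frcnn = get_FRCNN_model(num_classes=1)   # 1 foreground class + background
#   frcnn.eval()
#   with torch.no_grad():
#       preds = frcnn([torch.rand(3, 1800, 1800)])
#   # preds[0] is a dict with "boxes", "labels" and "scores"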
# Some utility heads for the Bilateral model
class RoIpool(nn.Module):
    """Pools RoI features from two backbones and concatenates them channel-wise."""
    def __init__(self, pool):
        super().__init__()
        # One pooler per branch, each a copy of the original FRCNN box_roi_pool
        self.box_roi_pool1 = copy.deepcopy(pool)
        self.box_roi_pool2 = copy.deepcopy(pool)
    def forward(self, features, proposals, image_shapes):
        x = self.box_roi_pool1(features[0], proposals, image_shapes)
        y = self.box_roi_pool2(features[1], proposals, image_shapes)
        # Concatenate along the channel dimension: [N, C, 7, 7] + [N, C, 7, 7] -> [N, 2C, 7, 7]
        z = torch.cat((x, y), dim=1)
        return z
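# Shape sketch for RoIpool (assumes the default ResNet-50 FPN backbone with 256
# output channels and a 7x7 MultiScaleRoIAlign; numbers are illustrative):
#
#   pooler = RoIpool(frcnn_model.roi_heads.box_roi_pool)
#   # each branch pools to [num_proposals, 256, 7, 7]; the channel-wise concat
#   # gives [num_proposals, 512, 7, 7], matching the 512 * 7 * 7 box head below.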
class TwoMLPHead(nn.Module):
"""
Standard heads for FPN-based models
Args:
in_channels (int): number of input channels
representation_size (int): size of the intermediate representation
"""
    def __init__(self, in_channels, representation_size):
super().__init__()
self.fc6 = nn.Linear(in_channels, representation_size)
self.fc7 = nn.Linear(representation_size, representation_size)
def forward(self, x):
x = x.flatten(start_dim=1)
x = F.relu(self.fc6(x))
x = F.relu(self.fc7(x))
return x
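# A quick shape check for the head (illustrative sketch, not part of the model code):
#
#   head = TwoMLPHead(in_channels=512 * 7 * 7, representation_size=1024)
#   out = head(torch.rand(8, 512, 7, 7))   # flattened to [8, 512*7*7] -> [8, 1024]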
# Next, the bilateral model: two copies of the FRCNN backbone process the two
# images in each input pair, and their RoI features are concatenated before the box head
class Bilateral_model(nn.Module):
    def __init__(self, frcnn_model):
        super().__init__()
        self.frcnn = frcnn_model
        self.transform = copy.deepcopy(frcnn_model.transform)
        self.backbone1 = copy.deepcopy(frcnn_model.backbone)
        self.backbone2 = copy.deepcopy(frcnn_model.backbone)
        self.rpn = copy.deepcopy(frcnn_model.rpn)
        # The RPN and both backbones are frozen; only the RoI heads are trained
        for param in self.rpn.parameters():
            param.requires_grad = False
        for param in self.backbone1.parameters():
            param.requires_grad = False
        for param in self.backbone2.parameters():
            param.requires_grad = False
        box_roi_pool = RoIpool(frcnn_model.roi_heads.box_roi_pool)
        # 512 = 2 x 256 FPN channels after concatenation in RoIpool
        box_head = TwoMLPHead(512 * 7 * 7, 1024)
        box_predictor = copy.deepcopy(frcnn_model.roi_heads.box_predictor)
        # RoIHeads hyper-parameters (torchvision's FasterRCNN defaults, except for
        # the much lower score threshold, which matches the FRCNN model above)
        box_score_thresh = 0.001
        box_nms_thresh = 0.5
        box_detections_per_img = 100
        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
self.roi_heads = RoIHeads(
# Box
box_roi_pool,
box_head,
box_predictor,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
box_score_thresh,
box_nms_thresh,
box_detections_per_img,
        )
        # Required by the scripting warning check in forward(); torchvision's
        # GeneralizedRCNN initializes the same flag.
        self._has_warned = False
@torch.jit.unused
def eager_outputs(self, losses, detections):
if self.training:
return losses
return detections
def forward(self, images, targets=None):
"""
Args:
images (list[Tensor(tuples)]): images to be processed
targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
if self.training:
assert targets is not None
for target in targets:
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
else:
raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img[0].shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1]))
        # Split each input pair into the two branch inputs; duplicate targets for the second transform pass
        images1 = [img[0] for img in images]
        images2 = [img[1] for img in images]
        targets2 = copy.deepcopy(targets)
#print(images1.shape)
#print(images2.shape)
images1, targets = self.transform(images1, targets)
images2, targets2 = self.transform(images2, targets2)
# Check for degenerate boxes
# TODO: Move this to a function
if targets is not None:
for target_idx, target in enumerate(targets):
boxes = target["boxes"]
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError(
"All bounding boxes should have positive height and width."
f" Found invalid box {degen_bb} for target at index {target_idx}."
)
features1 = self.backbone1(images1.tensors)
features2 = self.backbone2(images2.tensors)
#print(self.backbone1.out_channels)
if isinstance(features1, torch.Tensor):
features1 = OrderedDict([("0", features1)])
if isinstance(features2, torch.Tensor):
features2 = OrderedDict([("0", features2)])
        # Proposals come from branch 1 only; RoIpool inside roi_heads then pools from both branches
        proposals, proposal_losses = self.rpn(images1, features1, targets)
        features = {0: features1, 1: features2}
        detections, detector_losses = self.roi_heads(features, proposals, images1.image_sizes, targets)
detections = self.transform.postprocess(detections, images1.image_sizes, original_image_sizes) # type: ignore[operator]
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
self._has_warned = True
return losses, detections
else:
return self.eager_outputs(losses, detections)
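# A minimal end-to-end sketch of the bilateral model (illustrative only; the input
# size, the paired dummy images and the target values below are assumptions, not
# part of this file's pipeline):
#
#   frcnn = get_FRCNN_model(num_classes=1)
#   bimodel = Bilateral_model(frcnn)
#
#   # Each batch element is a (view1, view2) pair of 3xHxW tensors.
#   pairs = [(torch.rand(3, 1800, 1800), torch.rand(3, 1800, 1800))]
#
#   # Inference: returns a list of dicts with "boxes", "labels" and "scores".
#   bimodel.eval()
#   with torch.no_grad():
#       detections = bimodel(pairs)
#
#   # Training: targets follow the usual torchvision detection format and the
#   # forward pass returns a dict of losses.
#   bimodel.train()
#   targets = [{"boxes": torch.tensor([[10., 10., 100., 100.]]),
#               "labels": torch.tensor([1])}]
#   losses = bimodel(pairs, targets)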