from typing import List, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from . import axis_ops, ilnr_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss
from .detr.detr import MLP
from .detr.transformer import Transformer
from .detr.backbone import Backbone, Joiner
from .detr.position_encoding import PositionEmbeddingSine
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from .detr import box_ops
from .detr.segmentation import (
    MHAttentionMap, MaskHeadSmallConv, dice_loss, sigmoid_focal_loss
)

class INTR(torch.nn.Module):
    """
    Implements the Interaction 3D Transformer (INTR).
    """
    def __init__(
        self,
        backbone_name='resnet50',
        image_size=[192, 256],
        ignore_index=-100,
        num_classes=1,
        num_queries=15,
        freeze_backbone=False,
        transformer_hidden_dim=256,
        transformer_dropout=0.1,
        transformer_nhead=8,
        transformer_dim_feedforward=2048,
        transformer_num_encoder_layers=6,
        transformer_num_decoder_layers=6,
        transformer_normalize_before=False,
        transformer_return_intermediate_dec=True,
        layers_movable=3,
        layers_rigid=3,
        layers_kinematic=3,
        layers_action=3,
        layers_axis=2,
        layers_affordance=3,
        affordance_focal_alpha=0.95,
        axis_bins=30,
        depth_on=True,
    ):
""" Initializes the model. | |
Parameters: | |
backbone: torch module of the backbone to be used. See backbone.py | |
transformer: torch module of the transformer architecture. See transformer.py | |
num_classes: number of object classes | |
num_queries: number of object queries, ie detection slot. This is the maximal number of objects | |
DETR can detect in a single image. For COCO, we recommend 100 queries. | |
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. | |
""" | |
        super().__init__()
        self._ignore_index = ignore_index
        self._image_size = image_size
        self._axis_bins = axis_bins
        self._affordance_focal_alpha = affordance_focal_alpha

        # backbone
        backbone_base = Backbone(backbone_name, not freeze_backbone, True, False)
        N_steps = transformer_hidden_dim // 2
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
        backbone = Joiner(backbone_base, position_embedding)
        backbone.num_channels = backbone_base.num_channels
        self.backbone = backbone

        self.transformer = Transformer(
            d_model=transformer_hidden_dim,
            dropout=transformer_dropout,
            nhead=transformer_nhead,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_num_encoder_layers,
            num_decoder_layers=transformer_num_decoder_layers,
            normalize_before=transformer_normalize_before,
            return_intermediate_dec=transformer_return_intermediate_dec,
        )
        hidden_dim = self.transformer.d_model
        self.hidden_dim = hidden_dim
        nheads = self.transformer.nhead
        self.num_queries = num_queries
        # before the transformer, input_proj maps the 2048-channel ResNet-50 features to
        # hidden_dim-channel transformer inputs
        self.input_proj = nn.Conv2d(self.backbone.num_channels, hidden_dim, kernel_size=1)
        # query_mlp maps 2D keypoint coordinates to hidden_dim-dimensional query embeddings
        self.query_mlp = MLP(2, hidden_dim, hidden_dim, 2)
        # bbox MLP
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        if layers_movable > 1:
            self.movable_embed = MLP(hidden_dim, hidden_dim, 3, layers_movable)
        elif layers_movable == 1:
            self.movable_embed = nn.Linear(hidden_dim, 3)
        else:
            raise ValueError("not supported")
        if layers_rigid > 1:
            self.rigid_embed = MLP(hidden_dim, hidden_dim, 2, layers_rigid)
        elif layers_rigid == 1:
            #self.rigid_embed = nn.Linear(hidden_dim, 2)
            self.rigid_embed = nn.Linear(hidden_dim, 3)
        else:
            raise ValueError("not supported")
        if layers_kinematic > 1:
            self.kinematic_embed = MLP(hidden_dim, hidden_dim, 3, layers_kinematic)
        elif layers_kinematic == 1:
            self.kinematic_embed = nn.Linear(hidden_dim, 3)
        else:
            raise ValueError("not supported")
        if layers_action > 1:
            self.action_embed = MLP(hidden_dim, hidden_dim, 3, layers_action)
        elif layers_action == 1:
            self.action_embed = nn.Linear(hidden_dim, 3)
        else:
            raise ValueError("not supported")
        if layers_axis > 1:
            #self.axis_embed = MLP(hidden_dim, hidden_dim, 4, layers_axis)
            self.axis_embed = MLP(hidden_dim, hidden_dim, 3, layers_axis)
            # classification
            # self.axis_embed = MLP(hidden_dim, hidden_dim, self._axis_bins * 2, layers_axis)
        elif layers_axis == 1:
            self.axis_embed = nn.Linear(hidden_dim, 3)
        else:
            raise ValueError("not supported")
        # affordance
        if layers_affordance > 1:
            self.aff_embed = MLP(hidden_dim, hidden_dim, 2, layers_affordance)
        elif layers_affordance == 1:
            self.aff_embed = nn.Linear(hidden_dim, 2)
        else:
            raise ValueError("not supported")
        # affordance head
        self.aff_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.aff_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)
        # mask head
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)
        # depth head
        self._depth_on = depth_on
        if self._depth_on:
            self.depth_query = nn.Embedding(1, hidden_dim)
            self.depth_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
            self.depth_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)
            self.depth_loss = ilnr_loss.MEADSTD_TANH_NORM_Loss()
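            # assume a fixed field of view of 1 radian when deriving the focal length
            # used by the virtual normal (VNL) loss below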
            fov = torch.tensor(1.0)
            focal_length = (image_size[1] / 2 / torch.tan(fov / 2)).item()
            self.vnl_loss = VNL_Loss(focal_length, focal_length, image_size)
            self.midas_loss = MidasLoss(alpha=0.1)

    def freeze_layers(self, names):
        """
        Freeze all parameters whose name contains any of the substrings in `names`.
        """
        for name, param in self.named_parameters():
            for freeze_name in names:
                if freeze_name in name:
                    #print(name + ' ' + freeze_name)
                    param.requires_grad = False

    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ):
""" | |
Model forward. Set backward = False if the model is inference only. | |
""" | |
        device = image.device
        # the number of queries can differ at runtime
        num_queries = keypoints.shape[1]

        # DETR forward
        samples = image
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        bs = features[-1].tensors.shape[0]
        src, mask = features[-1].decompose()
        assert mask is not None

        # sample keypoint queries from the positional embedding
        use_sine = False
        if use_sine:
            anchors = keypoints.float()
            anchors_float = anchors.clone()
            anchors_float = anchors_float.reshape(-1, 2)
            anchors_float[:, 0] = ((anchors_float[:, 0] / self._image_size[1]) - 0.5) * 2
            anchors_float[:, 1] = ((anchors_float[:, 1] / self._image_size[0]) - 0.5) * 2
            anchors_float = anchors_float.unsqueeze(1).unsqueeze(1)
            # 4x256x1x1
            keypoint_queries = F.grid_sample(
                #pos[0].repeat(self.num_queries, 1, 1, 1),
                pos[-1].repeat(self.num_queries, 1, 1, 1),
                anchors_float,
                mode='nearest',
                align_corners=True
            )
            # 4 x 10 (number of object queries) x 256
            keypoint_queries = keypoint_queries.squeeze().reshape(-1, self.num_queries, self.hidden_dim)
        else:
            # use a learned MLP to map the positional encoding
            anchors = keypoints.float()
            anchors_float = anchors.clone()
            anchors_float[:, :, 0] = ((anchors_float[:, :, 0] / self._image_size[1]) - 0.5) * 2
            anchors_float[:, :, 1] = ((anchors_float[:, :, 1] / self._image_size[0]) - 0.5) * 2
            keypoint_queries = self.query_mlp(anchors_float)

        # append depth_query if the model is learning depth.
        if self._depth_on:
            bs = keypoint_queries.shape[0]
            depth_query = self.depth_query.weight.unsqueeze(0).repeat(bs, 1, 1)
            keypoint_queries = torch.cat((keypoint_queries, depth_query), dim=1)

        # transformer forward
        src_proj = self.input_proj(src)
        hs, memory = self.transformer(src_proj, mask, keypoint_queries, pos[-1])
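        # hs stacks the decoder output of every layer (return_intermediate_dec=True);
        # hs[-1] is the final layer. With depth enabled, the last query of that output
        # is the depth query appended above and is split off here.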
        if self._depth_on:
            depth_hs = hs[-1][:, -1:]
            ord_hs = hs[-1][:, :-1]
        else:
            ord_hs = hs[-1]

        outputs_coord = self.bbox_embed(ord_hs).sigmoid()
        outputs_movable = self.movable_embed(ord_hs)
        outputs_rigid = self.rigid_embed(ord_hs)
        outputs_kinematic = self.kinematic_embed(ord_hs)
        outputs_action = self.action_embed(ord_hs)

        # axis forward
        outputs_axis = self.axis_embed(ord_hs).sigmoid()
        # sigmoid range is 0 to 1, we want it to be -1 to 1
        outputs_axis = (outputs_axis - 0.5) * 2

        # affordance forward
        bbox_aff = self.aff_attention(ord_hs, memory, mask=mask)
        aff_masks = self.aff_head(src_proj, bbox_aff, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_aff_masks = aff_masks.view(bs, num_queries, aff_masks.shape[-2], aff_masks.shape[-1])

        # mask forward
        bbox_mask = self.bbox_attention(ord_hs, memory, mask=mask)
        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        # depth forward
        outputs_depth = None
        if self._depth_on:
            depth_att = self.depth_attention(depth_hs, memory, mask=mask)
            depth_masks = self.depth_head(
                src_proj,
                depth_att,
                [features[2].tensors, features[1].tensors, features[0].tensors]
            )
            outputs_depth = depth_masks.view(bs, 1, depth_masks.shape[-2], depth_masks.shape[-1])

        out = {
            'pred_boxes': box_ops.box_cxcywh_to_xyxy(outputs_coord),
            'pred_movable': outputs_movable,
            'pred_rigid': outputs_rigid,
            'pred_kinematic': outputs_kinematic,
            'pred_action': outputs_action,
            'pred_masks': outputs_seg_masks,
            'pred_axis': outputs_axis,
            'pred_depth': outputs_depth,
            'pred_affordance': outputs_aff_masks,
        }
        if not backward:
            return out

        # backward
        src_boxes = outputs_coord
        target_boxes = bbox
        target_boxes = box_ops.box_xyxy_to_cxcywh(target_boxes)
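        # boxes of missing objects are padded with negative coordinates, so the first
        # coordinate tells whether a ground-truth box is present for a query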
        bbox_valid = bbox[:, :, 0] > -0.5
        num_boxes = bbox_valid.sum()
        if num_boxes == 0:
            out['loss_bbox'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_giou'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
            loss_bbox = loss_bbox * bbox_valid.unsqueeze(2)  # remove invalid
            out['loss_bbox'] = loss_bbox.sum() / num_boxes
            loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes).reshape(-1, 4),
                box_ops.box_cxcywh_to_xyxy(target_boxes).reshape(-1, 4),
            )).reshape(-1, self.num_queries)
            loss_giou = loss_giou * bbox_valid  # remove invalid
            out['loss_giou'] = loss_giou.sum() / num_boxes

        # affordance
        affordance_valid = affordance[:, :, 0] > -0.5
        if affordance_valid.sum() == 0:
            out['loss_affordance'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            src_aff_masks = outputs_aff_masks[affordance_valid]
            tgt_aff_masks = affordance_map[affordance_valid]
            src_aff_masks = src_aff_masks.flatten(1)
            tgt_aff_masks = tgt_aff_masks.flatten(1)
            loss_aff = sigmoid_focal_loss(
                src_aff_masks,
                tgt_aff_masks,
                affordance_valid.sum(),
                alpha=self._affordance_focal_alpha,
            )
            out['loss_affordance'] = loss_aff

        # axis
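        # only queries with an annotated axis (positive first coordinate) contribute
        # to the axis losses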
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            out['loss_axis_angle'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_axis_offset'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_eascore'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            # regress angle
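            # the axis prediction is a normalized 2D direction plus an offset term;
            # axis_ops converts between this representation and endpoint (xyxy) form
            # relative to the box center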
            src_axis_angle = outputs_axis[axis_valid]
            src_axis_angle_norm = F.normalize(src_axis_angle[:, :2])
            src_axis_angle = torch.cat((src_axis_angle_norm, src_axis_angle[:, 2:]), dim=-1)
            target_axis_xyxy = axis[axis_valid]
            axis_center = target_boxes[axis_valid].clone()
            axis_center[:, 2:] = axis_center[:, :2]
            target_axis_angle = axis_ops.line_xyxy_to_angle(target_axis_xyxy, center=axis_center)
            loss_axis_angle = F.l1_loss(src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum') / num_axis
            loss_axis_offset = F.l1_loss(src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum') / num_axis
            out['loss_axis_angle'] = loss_axis_angle
            out['loss_axis_offset'] = loss_axis_offset
            src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center)
            target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center)
            axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy)
            loss_eascore = 1 - axis_eascore
            out['loss_eascore'] = loss_eascore.mean()

        loss_movable = F.cross_entropy(outputs_movable.permute(0, 2, 1), movable, ignore_index=self._ignore_index)
        if torch.isnan(loss_movable):
            loss_movable = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_movable'] = loss_movable
        loss_rigid = F.cross_entropy(outputs_rigid.permute(0, 2, 1), rigid, ignore_index=self._ignore_index)
        if torch.isnan(loss_rigid):
            loss_rigid = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_rigid'] = loss_rigid
        loss_kinematic = F.cross_entropy(outputs_kinematic.permute(0, 2, 1), kinematic, ignore_index=self._ignore_index)
        if torch.isnan(loss_kinematic):
            loss_kinematic = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_kinematic'] = loss_kinematic
        loss_action = F.cross_entropy(outputs_action.permute(0, 2, 1), action, ignore_index=self._ignore_index)
        if torch.isnan(loss_action):
            loss_action = torch.tensor(0.0, requires_grad=True).to(device)
        out['loss_action'] = loss_action
        # depth backward
        if self._depth_on:
            # (bs, 1, H, W)
            src_depths = interpolate(outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False)
            src_depths = src_depths.clamp(min=0.0, max=1.0)
            tgt_depths = depth.unsqueeze(1)  # (bs, H, W) -> (bs, 1, H, W)
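            # samples without ground-truth depth are assumed to be padded with
            # non-positive values; detect them via the top-left pixel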
            valid_depth = depth[:, 0, 0] > 0
            if valid_depth.any():
                src_depths = src_depths[valid_depth]
                tgt_depths = tgt_depths[valid_depth]
                depth_mask = tgt_depths > 1e-8
                midas_loss, ssi_loss, reg_loss = self.midas_loss(src_depths, tgt_depths, depth_mask)
                loss_vnl = self.vnl_loss(tgt_depths, src_depths)
                out['loss_depth'] = midas_loss
                out['loss_vnl'] = loss_vnl
            else:
                out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device)
                out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            out['loss_depth'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_vnl'] = torch.tensor(0.0, requires_grad=True).to(device)

        # mask backward
        tgt_masks = masks
        src_masks = interpolate(outputs_seg_masks, size=tgt_masks.shape[-2:], mode='bilinear', align_corners=False)
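        # only instances whose ground-truth mask covers more than 10 pixels
        # contribute to the mask losses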
        valid_mask = tgt_masks.sum(dim=-1).sum(dim=-1) > 10
        if valid_mask.sum() == 0:
            out['loss_mask'] = torch.tensor(0.0, requires_grad=True).to(device)
            out['loss_dice'] = torch.tensor(0.0, requires_grad=True).to(device)
        else:
            num_masks = valid_mask.sum()
            src_masks = src_masks[valid_mask]
            tgt_masks = tgt_masks[valid_mask]
            src_masks = src_masks.flatten(1)
            tgt_masks = tgt_masks.flatten(1)
            tgt_masks = tgt_masks.view(src_masks.shape)
            out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks.float(), num_masks)
            out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks)
        return out
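

# Minimal inference sketch (hypothetical, not part of the original file): shapes follow
# the defaults above, and the relative imports require the surrounding package.
#
#   model = INTR(image_size=[192, 256], num_queries=15).eval()
#   image = torch.zeros(1, 3, 192, 256)
#   keypoints = torch.zeros(1, 15, 2)  # one (x, y) pixel coordinate per query
#   with torch.no_grad():
#       out = model(
#           image, valid=None, keypoints=keypoints, bbox=None, masks=None,
#           movable=None, rigid=None, kinematic=None, action=None,
#           affordance=None, affordance_map=None, depth=None, axis=None,
#           fov=None, backward=False,
#       )
#   # out['pred_boxes'], out['pred_masks'], out['pred_affordance'], out['pred_depth'], ...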