# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file provides the definition of the convolutional heads used to predict masks, as well as the losses.
"""
import io
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from PIL import Image

from . import box_ops
from .misc import NestedTensor, interpolate, nested_tensor_from_tensor_list

try:
    from panopticapi.utils import id2rgb, rgb2id
except ImportError:
    pass


class DETRsegm(nn.Module):
    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr

        if freeze_detr:
            for p in self.parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        bs = features[-1].tensors.shape[0]

        src, mask = features[-1].decompose()
        assert mask is not None
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out["aux_outputs"] = self.detr._set_aux_loss(outputs_class, outputs_coord)

        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        return out
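

# Shape summary for the output dict (inferred from the forward pass above, assuming a
# standard DETR-style detector is wrapped):
#   "pred_logits": [batch, num_queries, num_classes + 1] class logits (last index = no-object)
#   "pred_boxes":  [batch, num_queries, 4] normalized (cx, cy, w, h) coordinates
#   "pred_masks":  [batch, num_queries, H_m, W_m] mask logits at the resolution of the
#                  highest-resolution backbone feature passed to the mask head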


def _expand(tensor, length: int):
    return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
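

# A minimal shape sketch (hypothetical helper, never called; not part of the original file):
# _expand repeats a per-image tensor once per query so it can be concatenated with the
# per-query attention maps in MaskHeadSmallConv.forward.
def _demo_expand_shapes():
    x = torch.randn(2, 256, 25, 38)  # [batch, channels, H, W]
    # After expansion the batch and query dimensions are flattened together.
    assert _expand(x, 100).shape == (200, 256, 25, 38)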


class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using a FPN approach
    """

    def __init__(self, dim, fpn_dims, context_dim, nheads=8):
        super().__init__()

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(nheads, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(nheads, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(nheads, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(nheads, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(nheads, inter_dims[4])
        self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
        x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        x = self.out_lay(x)
        return x
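

# A minimal shape sketch (hypothetical helper, never called; not part of the original file),
# assuming a ResNet-style backbone where fpn_dims = [1024, 512, 256] correspond to
# progressively higher-resolution feature maps. The head runs per query: the output batch
# dimension is batch * num_queries, with a single mask-logit channel at the resolution of the
# last (highest-resolution) FPN feature.
def _demo_mask_head_shapes():
    hidden_dim, nheads, num_queries = 256, 8, 100
    head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
    src_proj = torch.randn(2, hidden_dim, 25, 38)            # projected backbone feature
    bbox_mask = torch.randn(2, num_queries, nheads, 25, 38)  # per-query attention maps
    fpns = [
        torch.randn(2, 1024, 50, 76),                        # 2x resolution of src_proj
        torch.randn(2, 512, 100, 152),                       # 4x
        torch.randn(2, 256, 200, 304),                       # 8x
    ]
    masks = head(src_proj, bbox_mask, fpns)
    assert masks.shape == (2 * num_queries, 1, 200, 304)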


class MHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask: Optional[Tensor] = None):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        weights = self.dropout(weights)
        return weights
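

# A minimal shape sketch (hypothetical helper, never called; not part of the original file):
# the module attends each object query to every spatial location of the encoder memory and
# returns per-head attention maps of shape [batch, num_queries, num_heads, H, W].
def _demo_mhattention_shapes():
    attn = MHAttentionMap(query_dim=256, hidden_dim=256, num_heads=8)
    q = torch.randn(2, 100, 256)     # decoder output: [batch, num_queries, hidden_dim]
    k = torch.randn(2, 256, 25, 38)  # encoder memory as a [batch, hidden_dim, H, W] map
    weights = attn(q, k)
    assert weights.shape == (2, 100, 8, 25, 38)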


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes
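

# A minimal usage sketch (hypothetical helper, never called; not part of the original file):
# per mask i the loss is 1 - (2 * sum(sigmoid(p_i) * t_i) + 1) / (sum(sigmoid(p_i)) + sum(t_i) + 1),
# summed over masks and divided by num_boxes.
def _demo_dice_loss():
    logits = torch.randn(3, 200 * 304)                     # flattened mask logits for 3 matched queries
    targets = torch.randint(0, 2, (3, 200 * 304)).float()  # binary ground-truth masks
    loss = dice_loss(logits, targets, num_boxes=3)
    assert loss.dim() == 0                                 # scalar loss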


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = 0.25; a negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes
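

# A minimal usage sketch (hypothetical helper, never called; not part of the original file):
# the focal term (1 - p_t) ** gamma down-weights well-classified pixels and alpha balances
# positive vs negative pixels; the per-mask mean is then normalized by num_boxes.
def _demo_sigmoid_focal_loss():
    logits = torch.randn(3, 200 * 304)                     # flattened mask logits for 3 matched queries
    targets = torch.randint(0, 2, (3, 200 * 304)).float()  # binary ground-truth masks
    loss = sigmoid_focal_loss(logits, targets, num_boxes=3)
    assert loss.dim() == 0                                 # scalar loss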


class PostProcessSegm(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = F.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            ).byte()

        return results
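

# A minimal usage sketch (hypothetical helper, never called; not part of the original file;
# sizes are illustrative): PostProcessSegm upsamples the predicted mask logits to the padded
# input resolution, thresholds them, crops away the padding, and resizes each mask to the
# original image size, storing the result in the per-image results dict.
def _demo_postprocess_segm():
    postprocess = PostProcessSegm(threshold=0.5)
    outputs = {"pred_masks": torch.randn(1, 100, 200, 304)}  # [batch, num_queries, H_m, W_m]
    results = [{}]                                           # e.g. per-image dicts from the box post-processor
    max_target_sizes = torch.tensor([[800, 1216]])           # size fed to the model (after augmentation)
    orig_target_sizes = torch.tensor([[480, 640]])           # original image size
    results = postprocess(results, outputs, orig_target_sizes, max_target_sizes)
    assert results[0]["masks"].shape == (100, 1, 480, 640)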


class PostProcessPanoptic(nn.Module):
    """This class converts the output of the model to the final panoptic result, in the format expected by the
    coco panoptic API"""

    def __init__(self, is_thing_map, threshold=0.85):
        """
        Parameters:
            is_thing_map: This is a dict whose keys are the class ids, and whose values are booleans indicating
                          whether the class is a thing (True) or a stuff (False) class
            threshold: confidence threshold: segments with confidence lower than this will be deleted
        """
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map

    def forward(self, outputs, processed_sizes, target_sizes=None):
        """This function computes the panoptic prediction from the model's predictions.
        Parameters:
            outputs: This is a dict coming directly from the model. See the model doc for the content.
            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to
                             the model, i.e. the size after data augmentation but before batching.
            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
                          of each prediction. If left to None, it will default to the processed_sizes.
        """
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # we filter empty queries and detections below the confidence threshold
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            cur_scores = scores[keep]
            cur_classes = labels[keep]
            cur_masks = cur_masks[keep]
            cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
            cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of mask ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the area of the masks that appears on the image
                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)

                np_seg_img = (
                    torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                )
                m_id = torch.from_numpy(rgb2id(np_seg_img))

                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                # We now filter empty masks as long as we find some
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break
            else:
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds
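

# A minimal construction sketch (hypothetical helper, never called; not part of the original
# file): is_thing_map is typically derived from the dataset's category metadata, e.g. a
# COCO-panoptic-style split where low category ids are "thing" classes (treat the exact split
# below as an assumption). Each entry returned by forward pairs a PNG-encoded segment-id map
# ("png_string") with its "segments_info" list, which is the format consumed by the COCO
# panoptic API.
def _demo_postprocess_panoptic_setup():
    is_thing_map = {i: i <= 90 for i in range(201)}  # assumed COCO-style thing/stuff split
    return PostProcessPanoptic(is_thing_map, threshold=0.85)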