Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from detectron2.config import configurable | |
from detectron2.layers import Linear, ShapeSpec | |
class ZeroShotClassifier(nn.Module):
    """Score features against a bank of per-class embedding vectors.

    Input features are projected by a linear layer to ``zs_weight_dim``
    channels and multiplied with a ``D x (C + 1)`` weight matrix. The first
    ``C`` columns are the class embeddings (random, or loaded from a ``.npy``
    file); the last column is all zeros (presumably the background class --
    confirm against the caller).
    """

    # NOTE(review): this file imports detectron2's ``configurable`` but it is
    # not applied to ``__init__`` here, so the module can only be built from
    # explicit keyword arguments, e.g. ``cls(**cls.from_config(cfg, shape))``.
    # Confirm whether callers expect ``ZeroShotClassifier(cfg, input_shape)``.
    def __init__(
        self,
        input_shape: "ShapeSpec",
        *,
        num_classes: int,
        zs_weight_path: str,
        zs_weight_dim: int = 512,
        use_bias: float = 0.0,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
    ):
        """Build the projection layer and the class-embedding matrix.

        Args:
            input_shape: shape of the incoming features. Must expose
                ``channels``, ``width`` and ``height``; a plain ``int`` is
                also accepted and treated as the channel count.
            num_classes: number of classes ``C``.
            zs_weight_path: ``'rand'`` to create trainable random embeddings,
                otherwise a path readable by ``np.load`` holding a ``C x D``
                array that is stored as a (non-trainable) buffer.
            zs_weight_dim: embedding dimension ``D``.
            use_bias: a learnable scalar bias is created only when this value
                is negative; the value itself is the bias initialisation.
            norm_weight: L2-normalize each class embedding, and in
                ``forward`` also L2-normalize and rescale the features.
            norm_temperature: scale applied to the normalized features.
        """
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        # Flattened feature size; missing spatial extents count as 1.
        input_size = (
            input_shape.channels
            * (input_shape.width or 1)
            * (input_shape.height or 1)
        )

        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature

        # ``use_bias`` doubles as flag and initial value: only a negative
        # number enables the bias term.
        self.use_bias = use_bias < 0
        if self.use_bias:
            self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)

        self.linear = nn.Linear(input_size, zs_weight_dim)

        random_weights = zs_weight_path == 'rand'
        if random_weights:
            # The randn values are overwritten in place by normal_ below; the
            # call is kept so the random-number stream is consumed as before.
            zs_weight = torch.randn((zs_weight_dim, num_classes))
            nn.init.normal_(zs_weight, std=0.01)
        else:
            zs_weight = torch.tensor(
                np.load(zs_weight_path),
                dtype=torch.float32).permute(1, 0).contiguous()  # D x C

        # Append one all-zero column after the class columns.
        zs_weight = torch.cat(
            [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))],
            dim=1)  # D x (C + 1)

        if self.norm_weight:
            # Unit-normalize every column (the zero column stays zero).
            zs_weight = F.normalize(zs_weight, p=2, dim=0)

        if random_weights:
            self.zs_weight = nn.Parameter(zs_weight)
        else:
            self.register_buffer('zs_weight', zs_weight)

        assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the ``__init__`` keyword arguments read from ``cfg``.

        Args:
            cfg: config object exposing ``MODEL.ROI_HEADS.NUM_CLASSES`` and
                the ``MODEL.ROI_BOX_HEAD`` entries read below.
            input_shape: forwarded unchanged as ``input_shape``.

        Returns:
            dict mapping constructor argument names to their values.
        """
        return {
            'input_shape': input_shape,
            'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            'zs_weight_path': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH,
            'zs_weight_dim': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM,
            'use_bias': cfg.MODEL.ROI_BOX_HEAD.USE_BIAS,
            'norm_weight': cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT,
            'norm_temperature': cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP,
        }

    def forward(self, x, classifier=None):
        """Compute class scores for a batch of features.

        Args:
            x: features of shape ``B x D'``.
            classifier: optional ``C' x D`` embedding matrix used instead of
                the stored weights. It is used as given: no zero column is
                appended to it.

        Returns:
            Scores of shape ``B x (C + 1)`` with the stored weights, or
            ``B x C'`` when ``classifier`` is given.
        """
        x = self.linear(x)
        if classifier is not None:
            zs_weight = classifier.permute(1, 0).contiguous()  # D x C'
            if self.norm_weight:
                zs_weight = F.normalize(zs_weight, p=2, dim=0)
        else:
            zs_weight = self.zs_weight
        if self.norm_weight:
            # Features and embeddings are both unit-norm here, so the product
            # below is a cosine similarity scaled by the temperature.
            x = self.norm_temperature * F.normalize(x, p=2, dim=1)
        x = torch.mm(x, zs_weight)
        if self.use_bias:
            x = x + self.cls_bias
        return x