""" |
|
models and functions for building student and teacher networks for multi-granular losses. |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
|
|
import src.vision_transformer as vits |
|
from src.vision_transformer import trunc_normal_ |
|
|
|
|
|
class Instance_Superivsion_Head(nn.Module): |
|
""" |
|
a class to implement Instance Superivsion Head |
|
--in_dim: input dimension of projection head |
|
--hidden_dim: hidden dimension of projection head |
|
--out_dim: ouput dimension of projection and prediction heads |
|
--pred_hidden_dim: hidden dimension of prediction head |
|
--nlayers: layer number of projection head. prediction head has nlayers-1 layer |
|
--proj_bn: whether we use batch normalization in projection head |
|
--pred_bn: whether we use batch normalization in prediction head |
|
--norm_before_pred: whether we use normalization before prediction head |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_dim, |
|
hidden_dim=2048, |
|
out_dim=256, |
|
pred_hidden_dim=4096, |
|
nlayers=3, |
|
proj_bn=False, |
|
pred_bn=False, |
|
norm_before_pred=True, |
|
): |
|
super().__init__() |
|
nlayers = max(nlayers, 1) |
|
self.norm_before_pred = norm_before_pred |
|
|
|
self.projector = self._build_mlp( |
|
nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn |
|
) |
|
|
|
self.apply(self._init_weights) |
|
|
|
self.predictor = None |
|
if pred_hidden_dim > 0: |
|
self.predictor = self._build_mlp( |
|
nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn |
|
) |
|
|
|
def _init_weights(self, m): |
|
""" |
|
initilize the parameters in network |
|
""" |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=0.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False): |
|
""" |
|
build a mlp |
|
""" |
|
mlp = [] |
|
for layer in range(num_layers): |
|
dim1 = input_dim if layer == 0 else hidden_dim |
|
dim2 = output_dim if layer == num_layers - 1 else hidden_dim |
|
|
|
mlp.append(nn.Linear(dim1, dim2, bias=False)) |
|
|
|
if layer < num_layers - 1: |
|
if use_bn: |
|
mlp.append(nn.BatchNorm1d(dim2)) |
|
mlp.append(nn.GELU()) |
|
|
|
return nn.Sequential(*mlp) |
|
|
|
def forward(self, x, return_target=False): |
|
""" |
|
forward the input through projection head for teacher and |
|
projection/prediction heads for student |
|
""" |
|
feat = self.projector(x) |
|
|
|
if return_target: |
|
feat = nn.functional.normalize(feat, dim=-1, p=2) |
|
return feat |
|
|
|
if self.norm_before_pred: |
|
feat = nn.functional.normalize(feat, dim=-1, p=2) |
|
pred = self.predictor(feat) |
|
pred = nn.functional.normalize(pred, dim=-1, p=2) |
|
return pred |
|
|
|
|
|
class Local_Group_Superivsion_Head(nn.Module): |
|
""" |
|
a class to implement Local Group Superivsion Head which is the same as Instance Superivsion Head |
|
--in_dim: input dimension of projection head |
|
--hidden_dim: hidden dimension of projection head |
|
--out_dim: ouput dimension of projection and prediction heads |
|
--pred_hidden_dim: hidden dimension of prediction head |
|
--nlayers: layer number of projection head. prediction head has nlayers-1 layer |
|
--proj_bn: whether we use batch normalization in projection head |
|
--pred_bn: whether we use batch normalization in prediction head |
|
--norm_before_pred: whether we use normalization before prediction head |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_dim, |
|
hidden_dim=2048, |
|
out_dim=256, |
|
pred_hidden_dim=4096, |
|
nlayers=3, |
|
proj_bn=False, |
|
pred_bn=False, |
|
norm_before_pred=True, |
|
): |
|
super().__init__() |
|
nlayers = max(nlayers, 1) |
|
self.norm_before_pred = norm_before_pred |
|
|
|
self.projector = self._build_mlp( |
|
nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn |
|
) |
|
|
|
self.apply(self._init_weights) |
|
|
|
self.predictor = None |
|
if pred_hidden_dim > 0: |
|
self.predictor = self._build_mlp( |
|
nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn |
|
) |
|
|
|
def _init_weights(self, m): |
|
""" |
|
initilize the parameters in network |
|
""" |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=0.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False): |
|
""" |
|
build a mlp |
|
""" |
|
mlp = [] |
|
for layer in range(num_layers): |
|
dim1 = input_dim if layer == 0 else hidden_dim |
|
dim2 = output_dim if layer == num_layers - 1 else hidden_dim |
|
|
|
mlp.append(nn.Linear(dim1, dim2, bias=False)) |
|
|
|
if layer < num_layers - 1: |
|
if use_bn: |
|
mlp.append(nn.BatchNorm1d(dim2)) |
|
mlp.append(nn.GELU()) |
|
|
|
return nn.Sequential(*mlp) |
|
|
|
def forward(self, x, return_target=False): |
|
""" |
|
forward the input through projection head for teacher and |
|
projection/prediction heads for student |
|
""" |
|
feat = self.projector(x) |
|
|
|
if return_target: |
|
feat = nn.functional.normalize(feat, dim=-1, p=2) |
|
return feat |
|
|
|
if self.norm_before_pred: |
|
feat = nn.functional.normalize(feat, dim=-1, p=2) |
|
pred = self.predictor(feat) |
|
pred = nn.functional.normalize(pred, dim=-1, p=2) |
|
return pred |
|
|
|
|
|
class Group_Superivsion_Head(nn.Module): |
|
""" |
|
a class to implement Local Group Superivsion Head which is the same as Instance Superivsion Head |
|
--in_dim: input dimension of projection head |
|
--hidden_dim: hidden dimension of projection head |
|
--out_dim: ouput dimension of projection and prediction heads |
|
--pred_hidden_dim: hidden dimension of prediction head |
|
--nlayers: layer number of projection head. prediction head has nlayers-1 layer |
|
--proj_bn: whether we use batch normalization in projection head |
|
--pred_bn: whether we use batch normalization in prediction head |
|
--norm_before_pred: whether we use normalization before prediction head |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_dim, |
|
out_dim, |
|
hidden_dim=2048, |
|
bottleneck_dim=256, |
|
nlayers=3, |
|
use_bn=False, |
|
norm_last_layer=True, |
|
): |
|
super().__init__() |
|
nlayers = max(nlayers, 1) |
|
|
|
self.projector = self._build_mlp( |
|
nlayers, in_dim, hidden_dim, bottleneck_dim, use_bn=use_bn |
|
) |
|
self.apply(self._init_weights) |
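
        # The final layer is weight-normalized; its magnitude parameter (weight_g) is
        # set to 1 and, when norm_last_layer is True, frozen so that only the
        # direction of each output weight vector is learned (as in DINO-style heads).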
|
|
|
self.last_layer = nn.utils.weight_norm( |
|
nn.Linear(bottleneck_dim, out_dim, bias=False) |
|
) |
|
self.last_layer.weight_g.data.fill_(1) |
|
if norm_last_layer: |
|
self.last_layer.weight_g.requires_grad = False |
|
|
|
def _build_mlp(self, num_layers, in_dim, hidden_dim, output_dim, use_bn=False): |
|
""" |
|
build a mlp |
|
""" |
|
if num_layers == 1: |
|
mlp = nn.Linear(in_dim, output_dim) |
|
else: |
|
layers = [nn.Linear(in_dim, hidden_dim)] |
|
if use_bn: |
|
layers.append(nn.BatchNorm1d(hidden_dim)) |
|
layers.append(nn.GELU()) |
|
for _ in range(num_layers - 2): |
|
layers.append(nn.Linear(hidden_dim, hidden_dim)) |
|
if use_bn: |
|
layers.append(nn.BatchNorm1d(hidden_dim)) |
|
layers.append(nn.GELU()) |
|
layers.append(nn.Linear(hidden_dim, output_dim)) |
|
mlp = nn.Sequential(*layers) |
|
return mlp |
|
|
|
def _init_weights(self, m): |
|
""" |
|
initilize the parameters in network |
|
""" |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=0.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x): |
|
""" |
|
forward the input through the projection and last prediction layer |
|
""" |
|
feat = self.projector(x) |
|
feat = nn.functional.normalize(feat, dim=-1, p=2) |
|
feat = self.last_layer(feat) |
|
return feat |
|
|
|
|
|
class Block_mem(nn.Module): |
|
""" |
|
a class to implement a memory block for local group supervision |
|
--dim: feature vector dimenstion in the memory |
|
--K: memory size |
|
--top_n: number for neighbors in local group supervision |
|
""" |
|
|
|
def __init__(self, dim, K=2048, top_n=10): |
|
super().__init__() |
|
self.dim = dim |
|
self.K = K |
|
self.top_n = top_n |
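
        # FIFO memory: queue_q is compared against incoming queries to find neighbors,
        # while queue_k / queue_v store the entries returned as key/value pairs;
        # queue_ptr is the current write position in this circular buffer.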
|
|
|
self.register_buffer("queue_q", torch.randn(K, dim)) |
|
self.register_buffer("queue_k", torch.randn(K, dim)) |
|
self.register_buffer("queue_v", torch.randn(K, dim)) |
|
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) |
|
|
|
@torch.no_grad() |
|
def _dequeue_and_enqueue(self, query, weak_aug_flags): |
|
""" |
|
update memory queue |
|
""" |
|
|
|
|
|
len_weak = 0 |
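
        # Gather the queries from all GPUs; when weak-augmentation flags are provided,
        # only weakly augmented samples are kept and pushed into the memory.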
|
query = concat_all_gather(query) |
|
if weak_aug_flags is not None: |
|
weak_aug_flags = weak_aug_flags.cuda() |
|
weak_aug_flags = concat_all_gather(weak_aug_flags) |
|
idx_weak = torch.nonzero(weak_aug_flags) |
|
len_weak = len(idx_weak) |
|
if len_weak > 0: |
|
idx_weak = idx_weak.squeeze(-1) |
|
query = query[idx_weak] |
|
else: |
|
return len_weak |
|
|
|
all_size = query.shape[0] |
|
ptr = int(self.queue_ptr) |
|
remaining_size = ptr + all_size - self.K |
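
        # remaining_size > 0 means the new entries do not fit before the end of the
        # queue and the write has to wrap around to the beginning.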
|
if remaining_size <= 0: |
|
self.queue_q[ptr : ptr + all_size, :] = query |
|
self.queue_k[ptr : ptr + all_size, :] = query |
|
self.queue_v[ptr : ptr + all_size, :] = query |
|
            # Advance the write pointer, wrapping to 0 when the end of the queue is reached.
            ptr = ptr + all_size
            self.queue_ptr[0] = ptr % self.K
|
else: |
|
self.queue_q[ptr : self.K, :] = query[0 : self.K - ptr, :] |
|
self.queue_k[ptr : self.K, :] = query[0 : self.K - ptr, :] |
|
self.queue_v[ptr : self.K, :] = query[0 : self.K - ptr, :] |
|
|
|
self.queue_q[0:remaining_size, :] = query[self.K - ptr :, :] |
|
self.queue_k[0:remaining_size, :] = query[self.K - ptr :, :] |
|
self.queue_v[0:remaining_size, :] = query[self.K - ptr :, :] |
|
self.queue_ptr[0] = remaining_size |
|
return len_weak |
|
|
|
@torch.no_grad() |
|
def _get_similarity_index(self, x): |
|
""" |
|
compute the index of the top-n neighbors (key-value pair) in memory |
|
""" |
|
x = nn.functional.normalize(x, dim=-1) |
|
queue_q = nn.functional.normalize(self.queue_q, dim=-1) |
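
        # Cosine similarity between each (normalized) query and every memory entry;
        # keep the indices of the top_n most similar entries.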
|
|
|
cosine = x @ queue_q.T |
|
_, index = torch.topk(cosine, self.top_n, dim=-1) |
|
return index |
|
|
|
@torch.no_grad() |
|
def _get_similarity_samples(self, query, index=None): |
|
""" |
|
compute top-n neighbors (key-value pair) in memory |
|
""" |
|
if index is None: |
|
index = self._get_similarity_index(query) |
|
get_k = self.queue_k[index.view(-1)] |
|
get_v = self.queue_v[index.view(-1)] |
|
B, tn = index.shape |
|
get_k = get_k.view(B, tn, self.dim) |
|
get_v = get_v.view(B, tn, self.dim) |
|
return get_k, get_v |
|
|
|
def forward(self, query): |
|
""" |
|
forward to find the top-n neighbors (key-value pair) in memory |
|
""" |
|
get_k, get_v = self._get_similarity_samples(query) |
|
return get_k, get_v |
|
|
|
|
|
class vit_mem(nn.Module): |
|
""" |
|
a class to implement a memory for local group supervision |
|
--dim: feature vector dimenstion in the memory |
|
--K: memory size |
|
--top_n: number for neighbors in local group supervision |
|
""" |
|
|
|
def __init__(self, dim, K=2048, top_n=10): |
|
super().__init__() |
|
self.block = Block_mem(dim, K, top_n) |
|
|
|
def _dequeue_and_enqueue(self, query, weak_aug_flags): |
|
""" |
|
update memory queue |
|
""" |
|
query = query.float() |
|
weak_num = self.block._dequeue_and_enqueue(query, weak_aug_flags) |
|
return weak_num |
|
|
|
def forward(self, query): |
|
""" |
|
forward to find the top-n neighbors (key-value pair) in memory |
|
""" |
|
query = query.float() |
|
get_k, get_v = self.block(query) |
|
return get_k, get_v |
|
|
|
|
|
class Mugs_Wrapper(nn.Module): |
|
""" |
|
a class to implement a student or teacher wrapper for mugs |
|
--backbone: the backnone of student/teacher, e.g. ViT-small |
|
--instance_head: head, including projection/prediction heads, for instance supervision |
|
--local_group_head: head, including projection/prediction heads, for local group supervision |
|
--group_head: projection head for group supervision |
|
""" |
|
|
|
def __init__(self, backbone, instance_head, local_group_head, group_head): |
|
super(Mugs_Wrapper, self).__init__() |
|
backbone.fc, backbone.head = nn.Identity(), nn.Identity() |
|
self.backbone = backbone |
|
self.instance_head = instance_head |
|
self.local_group_head = local_group_head |
|
self.group_head = group_head |
|
|
|
def forward(self, x, return_target=False, local_group_memory_inputs=None): |
|
""" |
|
forward input to get instance/local-group/group targets or predictions |
|
""" |
|
|
|
if not isinstance(x, list): |
|
x = [x] |
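
        # Group the multi-crop inputs by spatial resolution so that crops of the same
        # size are concatenated and passed through the backbone in a single call;
        # idx_crops holds the end index of each resolution group.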
|
idx_crops = torch.cumsum( |
|
torch.unique_consecutive( |
|
torch.tensor([inp.shape[-1] for inp in x]), |
|
return_counts=True, |
|
)[1], |
|
0, |
|
) |
|
|
|
start_idx = 0 |
|
class_tokens = torch.empty(0).to(x[0].device) |
|
mean_patch_tokens = torch.empty(0).to(x[0].device) |
|
memory_class_tokens = torch.empty(0).to(x[0].device) |
|
for _, end_idx in enumerate(idx_crops): |
|
input = torch.cat(x[start_idx:end_idx]) |
|
token_feat, memory_class_token_feat = self.backbone( |
|
input, |
|
return_all=True, |
|
local_group_memory_inputs=local_group_memory_inputs, |
|
) |
|
|
|
|
|
            # [CLS] token feature from the backbone output
            class_token_feat = token_feat[:, 0]
|
class_tokens = torch.cat((class_tokens, class_token_feat)) |
|
|
|
start_idx = end_idx |
|
|
|
if self.local_group_head is not None: |
|
memory_class_tokens = torch.cat( |
|
(memory_class_tokens, memory_class_token_feat) |
|
) |
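
                # Only global crops (224x224) contribute mean patch tokens; the
                # training loop can push these into the local-group memory.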
|
if input.shape[-1] == 224: |
|
mean_patch_tokens = torch.cat( |
|
(mean_patch_tokens, token_feat[:, 1:].mean(dim=1)) |
|
) |
|
|
|
|
|
instance_feat = ( |
|
self.instance_head(class_tokens, return_target) |
|
if self.instance_head is not None |
|
else None |
|
) |
|
|
|
|
|
local_group_feat = ( |
|
self.local_group_head(memory_class_tokens, return_target) |
|
if self.local_group_head is not None |
|
else None |
|
) |
|
|
|
|
|
group_feat = ( |
|
self.group_head(class_tokens) if self.group_head is not None else None |
|
) |
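
        # mean_patch_tokens comes only from global crops and is detached so that
        # memory updates do not backpropagate into the backbone.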
|
return instance_feat, local_group_feat, group_feat, mean_patch_tokens.detach() |
|
|
|
|
|
def get_model(args): |
|
""" |
|
build a student or teacher for mugs, includeing backbone, instance/local-group/group heads, |
|
and memory buffer |
|
""" |
|
|
|
if args.arch in vits.__dict__.keys(): |
|
student = vits.__dict__[args.arch]( |
|
patch_size=args.patch_size, |
|
num_relation_blocks=1, |
|
drop_path_rate=args.drop_path_rate, |
|
) |
|
teacher = vits.__dict__[args.arch]( |
|
patch_size=args.patch_size, num_relation_blocks=1 |
|
) |
|
embed_dim = student.embed_dim |
|
    else:
        raise ValueError(f"Unknown architecture: {args.arch}")
|
|
|
|
|
student_mem = vit_mem( |
|
embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n |
|
) |
|
teacher_mem = vit_mem( |
|
embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n |
|
) |
|
|
|
|
|
student_instance_head, student_local_group_head, student_group_head = ( |
|
None, |
|
None, |
|
None, |
|
) |
|
teacher_instance_head, teacher_local_group_head, teacher_group_head = ( |
|
None, |
|
None, |
|
None, |
|
) |
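
    # A head is built only when its loss has a positive weight:
    # loss_weights = [instance, local-group, group].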
|
|
|
|
|
if args.loss_weights[0] > 0: |
|
student_instance_head = Instance_Superivsion_Head( |
|
in_dim=embed_dim, |
|
hidden_dim=2048, |
|
out_dim=args.instance_out_dim, |
|
pred_hidden_dim=4096, |
|
nlayers=3, |
|
proj_bn=args.use_bn_in_head, |
|
pred_bn=False, |
|
norm_before_pred=args.norm_before_pred, |
|
) |
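
        # The teacher heads use pred_hidden_dim=0, so no predictor is built and they
        # only produce (l2-normalized) targets.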
|
teacher_instance_head = Instance_Superivsion_Head( |
|
in_dim=embed_dim, |
|
hidden_dim=2048, |
|
out_dim=args.instance_out_dim, |
|
pred_hidden_dim=0, |
|
nlayers=3, |
|
proj_bn=args.use_bn_in_head, |
|
pred_bn=False, |
|
norm_before_pred=args.norm_before_pred, |
|
) |
|
|
|
|
|
if args.loss_weights[1] > 0: |
|
student_local_group_head = Local_Group_Superivsion_Head( |
|
in_dim=embed_dim, |
|
hidden_dim=2048, |
|
out_dim=args.local_group_out_dim, |
|
pred_hidden_dim=4096, |
|
nlayers=3, |
|
proj_bn=args.use_bn_in_head, |
|
pred_bn=False, |
|
norm_before_pred=args.norm_before_pred, |
|
) |
|
teacher_local_group_head = Local_Group_Superivsion_Head( |
|
in_dim=embed_dim, |
|
hidden_dim=2048, |
|
out_dim=args.local_group_out_dim, |
|
pred_hidden_dim=0, |
|
nlayers=3, |
|
proj_bn=args.use_bn_in_head, |
|
pred_bn=False, |
|
norm_before_pred=args.norm_before_pred, |
|
) |
|
|
|
|
|
if args.loss_weights[2] > 0: |
|
student_group_head = Group_Superivsion_Head( |
|
in_dim=embed_dim, |
|
out_dim=args.group_out_dim, |
|
hidden_dim=2048, |
|
bottleneck_dim=args.group_bottleneck_dim, |
|
nlayers=3, |
|
use_bn=args.use_bn_in_head, |
|
norm_last_layer=args.norm_last_layer, |
|
) |
|
teacher_group_head = Group_Superivsion_Head( |
|
in_dim=embed_dim, |
|
out_dim=args.group_out_dim, |
|
hidden_dim=2048, |
|
bottleneck_dim=args.group_bottleneck_dim, |
|
nlayers=3, |
|
use_bn=args.use_bn_in_head, |
|
norm_last_layer=args.norm_last_layer, |
|
) |
|
|
|
|
|
student = Mugs_Wrapper( |
|
student, student_instance_head, student_local_group_head, student_group_head |
|
) |
|
|
|
teacher = Mugs_Wrapper( |
|
teacher, teacher_instance_head, teacher_local_group_head, teacher_group_head |
|
) |
|
|
|
return student, teacher, student_mem, teacher_mem |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def concat_all_gather(tensor): |
|
""" |
|
Performs all_gather operation on the provided tensors. |
|
*** Warning ***: torch.distributed.all_gather has no gradient. |
|
""" |
|
|
|
tensors_gather = [ |
|
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) |
|
] |
|
torch.distributed.all_gather(tensors_gather, tensor, async_op=False) |
|
|
|
output = torch.cat(tensors_gather, dim=0) |
|
return output |
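

# A minimal usage sketch of get_model() (hypothetical argument values; the real ones
# come from the training script's command-line flags):
#
#   import argparse
#   args = argparse.Namespace(
#       arch="vit_small",
#       patch_size=16,
#       drop_path_rate=0.1,
#       local_group_queue_size=65536,
#       local_group_knn_top_n=8,
#       loss_weights=[1.0, 1.0, 1.0],
#       instance_out_dim=256,
#       local_group_out_dim=256,
#       group_out_dim=65536,
#       group_bottleneck_dim=256,
#       use_bn_in_head=False,
#       norm_before_pred=True,
#       norm_last_layer=True,
#   )
#   student, teacher, student_mem, teacher_mem = get_model(args)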
|
|