# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    """Projection head: MLP -> L2 normalization -> weight-normalized linear layer.

    The input is passed through an MLP (built by ``_build_mlp``) that maps it to
    ``bottleneck_dim`` features. The result is L2-normalized along the last
    dimension. It is then projected to ``out_dim`` by a bias-free,
    weight-normalized linear layer whose magnitude parameter ``weight_g`` is
    initialized to 1.

    Args:
        in_dim: Number of input features (size of the last input dimension).
        out_dim: Number of output features.
        use_bn: If True, insert ``nn.BatchNorm1d`` after every hidden linear
            layer of the MLP.
        nlayers: Number of linear layers in the MLP. Values below 1 are
            clamped to 1 (a single linear layer, no hidden layers).
        hidden_dim: Width of the MLP hidden layers. Unused when the MLP has a
            single layer; required otherwise.
        bottleneck_dim: Output width of the MLP, i.e. the dimension that is
            L2-normalized before the final projection.
        mlp_bias: Whether the MLP linear layers have a bias term. The final
            projection never has one.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(
            nlayers,
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )
        # Re-initialize the MLP only: ``apply`` runs before ``last_layer``
        # exists, so the final projection keeps nn.Linear's default init.
        self.apply(self._init_weights)
        # NOTE: this is the legacy ``torch.nn.utils.weight_norm``, which recent
        # PyTorch releases mark as deprecated. It is kept on purpose:
        # ``torch.nn.utils.parametrizations.weight_norm`` renames the
        # ``weight_g`` / ``weight_v`` parameters and would break existing
        # checkpoints.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        # weight = g * v / ||v|| per output row; initializing g to 1 starts every
        # row of the effective weight at unit norm.
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Give every ``nn.Linear`` truncated-normal weights (std=0.02) and zero bias."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map ``x`` of shape (..., in_dim) to (..., out_dim).

        NOTE(review): with ``use_bn=True`` the BatchNorm1d layers presumably
        require a 2-D (batch, in_dim) input — confirm against callers.
        """
        x = self.mlp(x)
        # 1e-12 underflows to 0 in float16, so use an eps float16 can represent.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(
    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
):
    """Build the MLP trunk used by ``DINOHead``.

    When ``nlayers == 1`` this returns a bare ``nn.Linear(in_dim, bottleneck_dim)``.
    Otherwise it returns an ``nn.Sequential`` made of:

    - hidden blocks ``Linear -> [BatchNorm1d] -> GELU``, where the first maps
      ``in_dim -> hidden_dim`` and the rest map ``hidden_dim -> hidden_dim``;
    - a final ``Linear(hidden_dim, bottleneck_dim)``.

    The layer order determines the ``state_dict`` keys, so keep it stable.

    Args:
        nlayers: Total number of linear layers. Values below 1 still produce
            one hidden block (``DINOHead`` clamps ``nlayers`` to >= 1 first).
        in_dim: Input feature width.
        bottleneck_dim: Output feature width.
        hidden_dim: Hidden width; required unless ``nlayers == 1``.
        use_bn: Insert ``nn.BatchNorm1d`` after each hidden linear layer.
        bias: Whether the linear layers carry a bias term.

    Raises:
        ValueError: If ``nlayers != 1`` and ``hidden_dim`` is None.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    if hidden_dim is None:
        raise ValueError(
            f"hidden_dim must be set when nlayers != 1 (got nlayers={nlayers})"
        )

    layers = []
    block_in = in_dim
    # nlayers - 1 hidden blocks; max(..., 1) keeps one block for nlayers < 1.
    for _ in range(max(nlayers - 1, 1)):
        layers.append(nn.Linear(block_in, hidden_dim, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        block_in = hidden_dim
    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)