from PIL.Image import Image
from typing import Union

from sklearn.decomposition import PCA
import torch
from torch import nn
from torchvision import transforms as tfs

# ImageNet normalization statistics used by the resize-and-normalize transform.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# torch.hub entry points and the model names each repo exposes.
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = ['dino_vits16',
                   'dino_vits8',
                   'dino_vitb16',
                   'dino_vitb8',
                   'dino_xcit_small_12_p16',
                   'dino_xcit_small_12_p8',
                   'dino_xcit_medium_24_p16',
                   'dino_xcit_medium_24_p8',
                   'dino_resnet50']

DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14',
                     'dinov2_vitb14',
                     'dinov2_vitl14',
                     'dinov2_vitg14']
class DINO(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINO, self).__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        # The extractor is inference-only: freeze all weights.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        # Patch tokens from the last block; drop the leading [CLS] token.
        # Note: get_intermediate_layers / patch_embed are ViT-specific, so
        # the XCiT and ResNet entries in DINO_MODEL_TYPE would need
        # different handling here.
        out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
        out = out[:, 1:, :]
        h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
            img.shape[3] / self.model.patch_embed.patch_size
        )
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)  # NxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)
        if self.pca:
            # reshape (not view): the tensor is non-contiguous after the
            # permutes above. PCA is refit on every call, so projected
            # features are not comparable across calls.
            B, H, W, C = out.shape
            out = out.reshape(-1, C).cpu().numpy()
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        return transforms(img)
class DINOV2(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINOV2, self).__init__()
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        # The extractor is inference-only: freeze all weights.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None

    def set_pca(self, dim=64):
        return PCA(n_components=dim)

    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        # DINOv2 exposes normalized patch tokens directly; no [CLS] token
        # to strip here.
        out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
        h, w = int(img.shape[2] / self.model.patch_size), int(
            img.shape[3] / self.model.patch_size
        )
        dim = out.shape[-1]
        out = out.reshape(-1, h, w, dim)  # NxHxWxC patch-feature grid
        dtype = out.dtype
        if size is not None:
            out = torch.nn.functional.interpolate(
                out.permute(0, 3, 1, 2), size=size, mode='bilinear'
            ).permute(0, 2, 3, 1)
        if self.pca:
            # reshape (not view): the tensor is non-contiguous after the
            # permutes above. PCA is refit on every call, so projected
            # features are not comparable across calls.
            B, H, W, C = out.shape
            out = out.reshape(-1, C).cpu().numpy()
            out = self.pca.fit_transform(out)
            out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out

    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)

    @staticmethod
    def transform(img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        return transforms(img)
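

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: it assumes a
    # local image at the hypothetical path 'example.jpg' and an available
    # CUDA device (pass device='cpu' otherwise). torch.hub.load fetches
    # the pretrained weights on first use.
    from PIL import Image as pil_image

    img = pil_image.open('example.jpg').convert('RGB')

    # DINO ViT-S/16: a 224x224 input yields a 14x14 grid of 384-d tokens.
    dino = DINO('dino_vits16', device='cuda')
    print(dino(img).shape)  # torch.Size([1, 14, 14, 384])

    # DINOv2 ViT-S/14 with PCA to 64 dims and bilinear resampling to 64x64.
    dinov2 = DINOV2('dinov2_vits14', device='cuda', pca_dim=64)
    print(dinov2(img, size=(64, 64)).shape)  # torch.Size([1, 64, 64, 64])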