import torch
import torch.hub
from torch import nn
from transformers import (
    CLIPVisionModel,
    CLIPVisionConfig,
    CLIPModel,
    CLIPProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPTextConfig,
    CLIPVisionModelWithProjection,
    ResNetModel,
    ResNetConfig,
)
from PIL import Image
import requests

class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): HuggingFace checkpoint path; an empty string builds a
                randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch with key "img" (torch.Tensor).
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features
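
# A minimal usage sketch (not part of the original module). The checkpoint
# name and the 224x224 input shape are assumptions for illustration:
#
#   model = CLIP("openai/clip-vit-base-patch32")
#   batch = {"img": torch.randn(2, 3, 224, 224)}
#   feats = model(batch)  # (2, 50, 768): CLS token + 7x7 patch tokens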

# Same architecture and forward pass as CLIP above; kept as a separate class.
class CLIPJZ(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch with key "img" (torch.Tensor).
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features

class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the StreetCLIP model and its processor."""
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts pooled CLIP image features.

        Args:
            x (dict): Input batch with keys "img" (raw images for the
                processor) and "gps" (torch.Tensor, used only to infer the
                target device).
        """
        features = self.clip.get_image_features(
            **self.transform(images=x["img"], return_tensors="pt").to(x["gps"].device)
        ).unsqueeze(1)
        return features
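
# Usage sketch for StreetCLIP (checkpoint name and PIL input are assumptions).
# Unlike the encoders above, the processor expects raw images, and "gps" is
# only read to pick the target device:
#
#   model = StreetCLIP("geolocal/StreetCLIP")
#   batch = {"img": [Image.new("RGB", (336, 336))],
#            "gps": torch.zeros(1, 2)}
#   feats = model(batch)  # (1, 1, 768): pooled embedding, length-1 seq dim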

class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder with a projection head."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            # Use the projection variant here too, otherwise forward() would
            # fail: plain CLIPVisionModel outputs have no image_embeds.
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts projected CLIP features from an image.

        Args:
            x (dict): Input batch with key "img" (torch.Tensor).
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state

class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the frozen CLIP text encoder."""
        super().__init__()
        if path == "":
            config_text = CLIPTextConfig()
            self.clip = CLIPTextModelWithProjection(config_text)
            # AutoTokenizer has no default constructor; fall back to the base
            # CLIP tokenizer (assumed checkpoint, adjust as needed).
            self.transform = AutoTokenizer.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.clip = CLIPTextModelWithProjection.from_pretrained(path)
            self.transform = AutoTokenizer.from_pretrained(path)
        # The text tower is kept frozen.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict): Input batch with keys "text" (list of str) and "gps"
                (torch.Tensor, used only to infer the target device).
        """
        features = self.clip(
            **self.transform(x["text"], padding=True, return_tensors="pt").to(
                x["gps"].device
            )
        ).text_embeds
        return features
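
# Usage sketch for TextEncoder (checkpoint name is an assumption). The "gps"
# tensor is only used to move the tokenized batch to the right device:
#
#   enc = TextEncoder("openai/clip-vit-base-patch32")
#   batch = {"text": ["a street in Paris"], "gps": torch.zeros(1, 2)}
#   embeds = enc(batch)  # (1, 512) projected text embeddings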

class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINOv2 model from torch.hub.

        Args:
            tag (str): torch.hub entrypoint, e.g. "dinov2_vitb14".
        """
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        self.stride = 14  # DINOv2 ViTs use a patch size of 14

    def forward(self, x):
        """Predicts DINOv2 features from an image."""
        x = x["img"]
        # Crop H and W down to multiples of the patch size so the image
        # tiles exactly into 14x14 patches.
        _, _, H, W = x.shape
        H_new = H - H % self.stride
        W_new = W - W % self.stride
        x = x[:, :, :H_new, :W_new]
        # Forward through the backbone and keep the pre-norm token sequence.
        x = self.dino.forward_features(x)
        x = x["x_prenorm"]
        return x
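
# Usage sketch for DINOv2 (tag is an assumption; torch.hub needs network
# access on first load). A 224x224 input is already a multiple of 14, so no
# cropping happens:
#
#   model = DINOv2("dinov2_vitb14")
#   batch = {"img": torch.randn(2, 3, 224, 224)}
#   tokens = model(batch)  # (2, 257, 768): CLS token + 16x16 patch tokens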

class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model."""
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts pooled ResNet features from an image.

        Args:
            x (dict): Input batch with key "img" (torch.Tensor).
        """
        # pooler_output has shape (B, C, 1, 1); squeeze only the spatial
        # dimensions so a batch of size 1 is not collapsed.
        features = self.resnet(x["img"])["pooler_output"]
        return features.squeeze(-1).squeeze(-2)
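
if __name__ == "__main__":
    # Minimal smoke test, added for illustration. Both models are built with
    # random weights from the default configs, so nothing is downloaded;
    # shapes below follow those defaults (CLIP ViT-B-like, ResNet-50-like).
    clip = CLIP("")
    resnet = ResNet("")
    batch = {"img": torch.randn(2, 3, 224, 224)}
    print(clip(batch).shape)    # torch.Size([2, 50, 768])
    print(resnet(batch).shape)  # torch.Size([2, 2048])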