# yunusserhat's picture
# Upload 40 files
# 94f372a verified
import torch.hub
from transformers import (
CLIPVisionModel,
CLIPVisionConfig,
CLIPModel,
CLIPProcessor,
AutoTokenizer,
CLIPTextModelWithProjection,
CLIPTextConfig,
CLIPVisionModelWithProjection,
ResNetModel,
ResNetConfig
)
from torch import nn
from PIL import Image
import requests
class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): HuggingFace checkpoint to load. An empty string
                builds a randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            self.clip = CLIPVisionModel(CLIPVisionConfig())
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        outputs = self.clip(pixel_values=x["img"])
        return outputs["last_hidden_state"]
class CLIPJZ(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): HuggingFace checkpoint; "" means a freshly
                initialized model with the default vision config.
        """
        super().__init__()
        self.clip = (
            CLIPVisionModel(CLIPVisionConfig())
            if path == ""
            else CLIPVisionModel.from_pretrained(path)
        )

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        return self.clip(pixel_values=x["img"])["last_hidden_state"]
class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the StreetCLIP model and its processor from a checkpoint.

        Args:
            path (str): HuggingFace checkpoint to load model and processor from.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP image features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch.
                "gps" is only read for its device.
        """
        batch = self.transform(images=x["img"], return_tensors="pt")
        batch = batch.to(x["gps"].device)
        embeds = self.clip.get_image_features(**batch)
        # Add a singleton token dimension: (B, D) -> (B, 1, D)
        return embeds.unsqueeze(1)
class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder with a projection head.

        Args:
            path (str): HuggingFace checkpoint to load. An empty string
                builds a randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            # Bug fix: the original built a plain CLIPVisionModel here, whose
            # output has no `image_embeds`, so forward() crashed on this
            # branch. Use the projection variant so both branches expose the
            # same output fields.
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch

        Returns:
            tuple: (projected image embeddings, last hidden state).
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state
class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the frozen CLIP text encoder.

        Args:
            path (str): HuggingFace checkpoint to load the model and its
                tokenizer from. Must be non-empty: a tokenizer needs
                vocabulary files, so there is no randomly initialized default.

        Raises:
            ValueError: if `path` is an empty string.
        """
        super().__init__()
        if path == "":
            # Bug fix: the original called AutoTokenizer() here, which always
            # raises (AutoTokenizer only supports from_pretrained). Fail fast
            # with a clear message instead of an obscure factory error.
            raise ValueError(
                "TextEncoder requires a pretrained checkpoint path; "
                "got an empty string."
            )
        self.clip = CLIPTextModelWithProjection.from_pretrained(path)
        self.transform = AutoTokenizer.from_pretrained(path)
        # Freeze the encoder: it is used as a fixed feature extractor.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict that contains "text": list): Input batch.
                "gps" is only read for its device.
        """
        tokens = self.transform(x["text"], padding=True, return_tensors="pt")
        features = self.clip(**tokens.to(x["gps"].device)).text_embeds
        return features
class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINO model.

        Args:
            tag (str): model tag for the facebookresearch/dinov2 torch hub repo.
        """
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        self.stride = 14  # ugly but dinov2 stride = 14

    def forward(self, x):
        """Predicts DINO features from an image.

        Crops the input so height and width are multiples of the patch
        stride before running the backbone.
        """
        img = x["img"]
        _, _, height, width = img.shape
        img = img[:, :, : height - height % self.stride, : width - width % self.stride]
        return self.dino.forward_features(img)["x_prenorm"]
class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model.

        Args:
            path (str): HuggingFace checkpoint to load. An empty string
                builds a randomly initialized model from the default config.
        """
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet50 features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch

        Returns:
            torch.Tensor of shape (batch, channels).
        """
        features = self.resnet(x["img"])["pooler_output"]
        # Bug fix: .squeeze() with no argument also removed the batch
        # dimension when the batch size was 1; only drop the trailing
        # 1x1 spatial dimensions of the pooled output.
        return features.squeeze(-1).squeeze(-1)