from __future__ import annotations

import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class ClipSimilarity(nn.Module):
    """Computes CLIP similarity scores between image pairs and caption pairs."""

    def __init__(self, name: str = "ViT-L/14"):
        super().__init__()
        assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px")
        # Input resolution expected by each CLIP variant; most models use 224.
        self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224)

        self.model, _ = clip.load(name, device="cpu", download_root="./")
        self.model.eval().requires_grad_(False)

        # CLIP's per-channel normalization statistics, registered as buffers so
        # they move with the module when it is transferred to another device.
        self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
        self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))

    def encode_text(self, text: list[str]) -> torch.Tensor:
        # Tokenize on CPU, then move the tokens to whichever device holds the model.
        tokens = clip.tokenize(text, truncate=True).to(next(self.parameters()).device)
        text_features = self.model.encode_text(tokens)
        # L2-normalize so downstream dot products are cosine similarities.
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return text_features

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        # Expects a (B, C, H, W) batch with values in [0, 1]. Resize to the
        # model's input resolution, then apply CLIP's channel normalization.
        image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
        image = image - rearrange(self.mean, "c -> 1 c 1 1")
        image = image / rearrange(self.std, "c -> 1 c 1 1")
        image_features = self.model.encode_image(image)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        return image_features

    def forward(
        self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        image_features_0 = self.encode_image(image_0)
        image_features_1 = self.encode_image(image_1)
        text_features_0 = self.encode_text(text_0)
        text_features_1 = self.encode_text(text_1)
        # Image-caption agreement for each pair.
        sim_0 = F.cosine_similarity(image_features_0, text_features_0)
        sim_1 = F.cosine_similarity(image_features_1, text_features_1)
        # Directional similarity: does the image edit move through CLIP space in
        # the same direction as the caption edit?
        sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0)
        # Image-image similarity: how much of the original content is preserved.
        sim_image = F.cosine_similarity(image_features_0, image_features_1)
        return sim_0, sim_1, sim_direction, sim_image
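
# Usage sketch (illustrative, not part of the original module): assumes random
# RGB batches in [0, 1] and made-up caption pairs; a real pipeline would pass
# decoded images instead. "ViT-B/32" is chosen here only to keep the check fast.
if __name__ == "__main__":
    clip_sim = ClipSimilarity(name="ViT-B/32")
    image_0 = torch.rand(2, 3, 256, 256)  # "before" images, (B, C, H, W) in [0, 1]
    image_1 = torch.rand(2, 3, 256, 256)  # "after" images
    text_0 = ["a photo of a dog", "a snowy mountain"]
    text_1 = ["a photo of a cat", "a snowy mountain at night"]
    sim_0, sim_1, sim_direction, sim_image = clip_sim(image_0, image_1, text_0, text_1)
    print(f"sim_0 (image_0 vs. text_0):       {sim_0.tolist()}")
    print(f"sim_1 (image_1 vs. text_1):       {sim_1.tolist()}")
    print(f"sim_direction (edit alignment):   {sim_direction.tolist()}")
    print(f"sim_image (content preservation): {sim_image.tolist()}")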