from transformers import CLIPProcessor


class ClipTransform(object):
    """Image preprocessing that wraps the Hugging Face CLIPProcessor for StreetCLIP."""

    def __init__(self, split):
        # `split` is accepted for interface compatibility with other transforms but is unused.
        self.transform = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

    def __call__(self, x):
        # return self.transform(images=x, return_tensors="pt")["pixel_values"].squeeze(0)
        return self.transform(images=[x], return_tensors="pt")


if __name__ == "__main__":
    # Sanity check: compare the CLIPProcessor output against the re-implemented
    # transform instantiated from the fast_clip Hydra config.
    import glob

    import torch
    import torchvision.transforms as transforms
    from hydra.utils import instantiate
    from omegaconf import DictConfig, OmegaConf
    from PIL import Image
    from torchvision.utils import save_image

    fast_clip_config = OmegaConf.load(
        "./configs/dataset/train_transform/fast_clip.yaml"
    )
    fast_clip_transform = instantiate(fast_clip_config)
    clip_transform = ClipTransform(None)

    img_paths = glob.glob("./datasets/osv5m/test/images/*.jpg")
    original_imgs, re_implemented_imgs, diff = [], [], []
    for i in range(16):
        img = Image.open(img_paths[i])
        # CLIPProcessor returns a BatchFeature; extract the pixel tensor for comparison.
        clip_img = clip_transform(img)["pixel_values"].squeeze(0)
        fast_clip_img = fast_clip_transform(img)
        original_imgs.append(clip_img)
        re_implemented_imgs.append(fast_clip_img)

        abs_diff = (clip_img - fast_clip_img).abs()
        diff.append(abs_diff)
        if abs_diff.max() > 1e-5:
            print(abs_diff.max())

    original_imgs = torch.stack(original_imgs)
    re_implemented_imgs = torch.stack(re_implemented_imgs)
    diff = torch.stack(diff)
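
# For reference, a minimal sketch of the torchvision pipeline that the
# fast_clip.yaml config is assumed to instantiate. This is an assumption, not
# the actual config: it mirrors the standard OpenAI CLIP preprocessing
# (bicubic resize and center crop to 336 px, assuming StreetCLIP's ViT-L/14-336
# backbone, rescale to [0, 1], then normalize with the CLIP mean/std).
#
# import torchvision.transforms as T
#
# fast_clip_sketch = T.Compose(
#     [
#         T.Resize(336, interpolation=T.InterpolationMode.BICUBIC),
#         T.CenterCrop(336),
#         T.ToTensor(),  # PIL image -> float tensor in [0, 1]
#         T.Normalize(
#             mean=(0.48145466, 0.4578275, 0.40821073),
#             std=(0.26862954, 0.26130258, 0.27577711),
#         ),
#     ]
# )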