from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
import json
import argparse
import numpy as np
class COCODataset(Dataset):
    """COCO val2014 images paired with either ground-truth Karpathy test captions or generated captions."""

    def __init__(self,
                 coco_root="/nas-ssd/jmincho/datasets/COCO/",
                 gen_caption_path=None,
                 is_gt=True):
        super().__init__()
        self.coco_root = Path(coco_root)
        self.image_dir = self.coco_root.joinpath('images/val2014')

        if is_gt:
            print("Loading karpathy splits")
            data_info_path = self.coco_root.joinpath('dataset_coco.json')
            with open(data_info_path) as f:
                karpathy_data = json.load(f)

            data = []
            for datum in karpathy_data['images']:
                # keep only the Karpathy test split
                if datum['split'] == 'test':
                    img_id = datum['filename'].split('.')[0]
                    new_datum = {
                        'img_id': img_id,
                        'captions': [d['raw'].strip() for d in datum['sentences']],
                    }
                    data.append(new_datum)
        else:
            print("Loading generated captions")
            gen_caption_path = Path(gen_caption_path)
            with open(gen_caption_path) as f:
                imgTogen_results = json.load(f)['imgToEval']
            data = []
            for img_id, img_data in imgTogen_results.items():
                new_datum = {
                    'img_id': img_id,
                    'captions': [img_data['caption']],
                }
                data.append(new_datum)

        self.data = data
        print('# images:', len(self.data))

        # Relies on the module-level `processor` created in the __main__ block below.
        self.img_transform = processor.feature_extractor
        self.tokenizer = processor.tokenizer
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]

        img_id = datum['img_id']
        if 'COCO' not in img_id:
            img_id = f'COCO_val2014_{str(img_id).zfill(12)}'
        img_fname = f"{img_id}.jpg"
        # e.g., COCO_val2014_000000522418.jpg
        img_path = self.image_dir.joinpath(img_fname)
        img = Image.open(img_path).convert("RGB")

        # take the first caption only
        caption = datum['captions'][0]

        return {
            "img": img,
            "caption": caption,
        }
    def collate_fn(self, datum_list):
        imgs = [datum['img'] for datum in datum_list]
        images = self.img_transform(imgs, return_tensors="pt")

        captions = [datum['caption'] for datum in datum_list]
        text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True)

        batch = {
            'images': images,
            'captions': text_tokens,
        }
        return batch
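
# Illustrative (hypothetical, commented-out) check of the batch layout produced by collate_fn,
# assuming `processor` has already been created and the ViT-B/32 image processor is used:
#   dataset = COCODataset(is_gt=True)                 # needs the module-level `processor`
#   loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))
#   batch['images']['pixel_values'].shape             # roughly torch.Size([2, 3, 224, 224])
#   batch['captions']['input_ids'].shape              # torch.Size([2, padded_seq_len])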
def compute_similarity(image_features, text_features, bs=1000):
    """Compute the full image-text similarity matrix in blocks of size `bs` to limit memory use."""
    max_pairs = image_features.shape[0]
    similarity_scores = torch.zeros(max_pairs, max_pairs)
    for v in range(0, max_pairs, bs):
        for t in range(0, max_pairs, bs):
            batch_visual_emb = image_features[v:v + bs]
            batch_caption_emb = text_features[t:t + bs]
            logits = batch_visual_emb @ batch_caption_emb.t()
            similarity_scores[v:v + bs, t:t + bs] = logits

    print('Done similarity')
    return similarity_scores
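
# Equivalence sketch (illustrative, commented out): the block-wise fill should match a single
# matrix product on toy, unit-normalized embeddings:
#   a = torch.randn(10, 512); a = a / a.norm(dim=-1, keepdim=True)
#   b = torch.randn(10, 512); b = b / b.norm(dim=-1, keepdim=True)
#   torch.allclose(compute_similarity(a, b, bs=4), a @ b.t())   # expected: True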
def compute_retrieval(a2b_sims, return_ranks=True):
    """
    Args:
        a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T)
            with shape (num_datapoints, num_datapoints).
    Returns:
        Retrieval metrics for that similarity.
    """
    npts = a2b_sims.shape[0]
    ranks = np.zeros(npts)
    top1 = np.zeros(npts)

    # loop over source embedding indices
    for index in range(npts):
        # order target embeddings by decreasing similarity
        inds = np.argsort(a2b_sims[index])[::-1]
        # find where the correct (diagonal) embedding is ranked
        where = np.where(inds == index)
        rank = where[0][0]
        ranks[index] = rank
        # save the top-1 retrieved index as well
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    report_dict = {
        "r1": r1, "r5": r5, "r10": r10, "r50": r50,
        "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10,
    }

    if return_ranks:
        return report_dict, (ranks, top1)
    else:
        return report_dict
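
# Usage sketch (toy example, commented out): a perfectly diagonal similarity matrix yields
# perfect recall, since every query ranks its own pair first:
#   sims = np.eye(100)
#   report, (ranks, top1) = compute_retrieval(sims)
#   report['r1']   # expected: 100.0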
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/")
    parser.add_argument('--gt', action='store_true')
    parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json")
    args = parser.parse_args()

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    print(f"Loaded CLIP at {device}")

    batch_size = 1000

    dataset = COCODataset(
        coco_root=args.coco_root,
        gen_caption_path=args.gen_caption_path,
        is_gt=args.gt,
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        num_workers=8)
    # forward all samples through CLIP
    image_features = []
    text_features = []
    for batch_idx, batch in enumerate(tqdm(data_loader)):
        with torch.no_grad():
            # move inputs to the same device as the model before the forward pass
            images = batch["images"].to(device)
            texts = batch["captions"].to(device)

            vision_outputs = model.vision_model(**images)
            text_outputs = model.text_model(**texts)

            image_embeds = vision_outputs[1]  # pooled output
            image_embeds = model.visual_projection(image_embeds)

            text_embeds = text_outputs[1]  # pooled output
            text_embeds = model.text_projection(text_embeds)

            # normalize features
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

            text_features.append(text_embeds.detach().cpu())
            image_features.append(image_embeds.detach().cpu())

    image_features = torch.cat(image_features, 0)
    text_features = torch.cat(text_features, 0)
    print('Done forward')

    # features were normalized per batch above; renormalizing here is a harmless no-op
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # If multiple captions per image were kept, retrieval could be run per caption index:
    # for cap_idx in range(text_features.shape[1]):
    #     similarity_scores = compute_similarity(image_features, text_features[:, cap_idx, :])
    #     i2t_dict = compute_retrieval(similarity_scores.numpy())
    #     t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    #     print(cap_idx, 'i2t', i2t_dict)
    #     print(cap_idx, 't2i', t2i_dict)

    similarity_scores = compute_similarity(image_features, text_features)
    i2t_dict = compute_retrieval(similarity_scores.numpy())
    t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    print('i2t', i2t_dict)
    print('t2i', t2i_dict)
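
# Example invocations (script name and paths are placeholders for the local setup):
#   python clip_retrieval_eval.py --gt --coco_root /path/to/COCO/
#   python clip_retrieval_eval.py --gen_caption_path ./eval_results/clipRN50_cider_test.json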