import argparse

import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from config import get_config
from model import build_model

# Download human-readable labels for ImageNet (not referenced later in this script).
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")


def parse_option():
    parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
    parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml", metavar="FILE",
                        help='path to config file')
    args, unparsed = parser.parse_known_args()
    config = get_config(args)
    return args, config


def build_transforms(img_size, center_crop=True):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)


def build_transforms4display(img_size, center_crop=True):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)


args, config = parse_option()

'''
build model
'''
model = build_model(config)

# The pretrained checkpoint is expected next to this script
# (despite the variable name, this is a local path, not a URL).
url = './in21k_yfcc14m_gcc15m_swin_base.pth'
checkpoint = torch.load(url, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
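
# Optional sketch (not part of the original script): move the model to GPU when one
# is available. The tensors built in recognize_image() below would then also need a
# matching .to(device) call.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)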

'''
build data transform
'''
eval_transforms = build_transforms(224, center_crop=True)
display_transforms = build_transforms4display(224, center_crop=True)
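
# Example (sketch, assuming an RGB image file such as "./elephants.png" from the
# examples below): both pipelines resize/crop to 224x224, but only eval_transforms
# applies ImageNet normalization.
#
#   x = np.array(Image.open("./elephants.png").convert("RGB"))
#   eval_transforms(x).shape      # torch.Size([3, 224, 224]) -- normalized, fed to the model
#   display_transforms(x).shape   # torch.Size([3, 224, 224]) -- unnormalized, used for display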

'''
build upsampler
'''
# upsampler = nn.Upsample(scale_factor=16, mode='bilinear')

'''
borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)
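
# Example usage (sketch): `img` is an HxWx3 float32 array in [0, 1] and `mask` an
# HxW (or HxWx1) float32 array in [0, 1], e.g. a normalized attention map.
#
#   img = np.float32(np.array(Image.open("./elephants.png").convert("RGB"))) / 255
#   mask = np.random.rand(*img.shape[:2]).astype(np.float32)  # placeholder mask
#   overlay = show_cam_on_image(img, mask, use_rgb=True)      # uint8 HxWx3 overlay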


def recognize_image(image, texts):
    candidates = texts.split(';')

    img_t = eval_transforms(image)
    img_d = display_transforms(image).permute(1, 2, 0).numpy()
    text_embeddings = model.get_text_embeddings(candidates)

    # compute image/text matching scores
    feat_img, feat_map, H, W = model.encode_image(img_t.unsqueeze(0), output_map=True)
    output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
    prediction = output.softmax(-1).flatten()

    # generate a spatial map for the top matched text and upsample it to the input size
    output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
    output_map = output_map.view(1, 1, H, W)
    output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
    output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
    output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())

    # overlay the map on the unnormalized display image and show crop + heatmap side by side
    heatmap = show_cam_on_image(img_d, output_map, use_rgb=True)
    show_img = np.concatenate((np.uint8(255 * img_d), heatmap), 1)

    return {candidates[i]: float(prediction[i]) for i in range(len(candidates))}, Image.fromarray(show_img)
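
# Direct call sketch (assuming "./elephants.png" from the examples below exists):
#
#   img = np.array(Image.open("./elephants.png").convert("RGB"))
#   scores, overlay = recognize_image(img, "an elephant; a horse; a dog")
#   print(scores)            # {candidate text: softmax score}
#   overlay.save("out.png")  # side-by-side crop input / heat map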


# Gradio I/O components (the old gr.inputs/gr.outputs namespaces have been removed
# from recent Gradio releases, so the current component classes are used here).
image = gr.Image()
label = gr.Label(num_top_classes=100)

description = "UniCL for Zero-shot Image Recognition. Given an image, our model maps it to an arbitrary text in a candidate pool."

gr.Interface(
    description=description,
    fn=recognize_image,
    inputs=["image", "text"],
    outputs=[
        label,
        gr.Image(type="pil", label="crop input/heat map"),
    ],
    examples=[
        ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
        ["./apple_with_ipod.jpg", "an ipod; an apple with a written note 'ipod'; an apple with a written note 'iphone'"],
        ["./apple_with_ipod.jpg", "a written note 'ipod'; a written note 'ipad'; a written note 'iphone'"],
        ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
        ["./donut.png", "a bread; a donut; some donuts"],
        ["./horse.png", "an image of horse; an image of cow; an image of dog"],
        ["./dog_and_cat.jfif", "a dog; a cat; dog and cat"],
    ],
    article=Path("docs/intro.md").read_text(),
).launch()