import clip
from PIL import Image
import pandas as pd
import torch
from dataloader.extract_features_dataloader import transform_resize, question_preprocess
from model.vqa_model import NetVQA
from dataclasses import dataclass
from torch.cuda.amp import autocast
import gradio as gr
@dataclass
class InferenceConfig:
    '''
    Describes the configuration of the inference process
    '''
    model: str = "RN50x64"
    checkpoint_root_clip: str = "./checkpoints/clip"
    checkpoint_root_head: str = "./checkpoints/head"

    use_question_preprocess: bool = True  # True: delete "?" at the end of the question

    aux_mapping = {0: "unanswerable",
                   1: "unsuitable",
                   2: "yes",
                   3: "no",
                   4: "number",
                   5: "color",
                   6: "other"}
    folds = 10

    # Data
    n_classes: int = 5726

    # class mapping
    class_mapping: str = "./data/annotations/class_mapping.csv"

    device = "cuda" if torch.cuda.is_available() else "cpu"
config = InferenceConfig()

# load class mapping
cm = pd.read_csv(config.class_mapping)
classid_to_answer = {}
for i in range(len(cm)):
    row = cm.iloc[i]
    classid_to_answer[row["class_id"]] = row["answer"]
# load CLIP backbone and the linear-head VQA model
clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip, device=config.device)

model = NetVQA(config).to(config.device)

config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
model.load_state_dict(model_state_dict, strict=True)
model.eval()
# Select Preprocessing
image_transforms = transform_resize(clip_model.visual.input_resolution)

if config.use_question_preprocess:
    question_transforms = question_preprocess
else:
    question_transforms = None

clip_model.eval()
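# Example of the question preprocessing (assuming, per the config comment above, that
# question_preprocess only strips a trailing question mark):
#   "Can you tell me the color of the dog?" -> "Can you tell me the color of the dog"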
def predict(img, text):
    # Gradio passes the image as a numpy array; convert and preprocess it
    img = Image.fromarray(img)
    img = image_transforms(img)
    img = img.unsqueeze(dim=0)

    if question_transforms is not None:
        question = question_transforms(text)
    else:
        question = text
    question_tokens = clip.tokenize(question, truncate=True)

    with torch.no_grad():
        # Extract CLIP features for image and question
        img = img.to(config.device)
        img_feature = clip_model.encode_image(img)

        question_tokens = question_tokens.to(config.device)
        question_feature = clip_model.encode_text(question_tokens)

        with autocast():
            output, output_aux = model(img_feature, question_feature)

    # Map class scores to answer strings and auxiliary answer categories
    prediction_vqa = dict()
    output = output.cpu().squeeze(0)
    for k, v in classid_to_answer.items():
        prediction_vqa[v] = float(output[k])

    prediction_aux = dict()
    output_aux = output_aux.cpu().squeeze(0)
    for k, v in config.aux_mapping.items():
        prediction_aux[v] = float(output_aux[k])

    return prediction_vqa, prediction_aux
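# A quick local smoke test (sketch only, not part of the Gradio app): it reuses one of the
# example images listed below and mimics the numpy array that gr.Image would pass to predict().
# import numpy as np
# answers, categories = predict(np.array(Image.open("examples/Augustiner.jpg")), "What is this?")
# print(max(answers, key=answers.get), max(categories, key=categories.get))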
description = """
Less Is More: Linear Layers on CLIP Features as Powerful VizWiz Model

Our approach focuses on visual question answering for visually impaired people. We fine-tuned our model on the <a href='https://vizwiz.org/tasks-and-datasets/vqa/'>CVPR Grand Challenge VizWiz 2022</a> dataset.
You may click on one of the examples or upload your own image and question. The app shows the predicted answer to your question and an answer category.
Link to our <a href='https://arxiv.org/abs/2206.05281'>paper</a>.
"""
gr.Interface(fn=predict,
             description=description,
             inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
             outputs=[gr.Label(label='Answer', num_top_classes=5), gr.Label(label='Answer Category', num_top_classes=7)],
             examples=[['examples/Augustiner.jpg', 'What is this?'],
                       ['examples/VizWiz_test_00006968.jpg', 'Can you tell me the color of the dog?'],
                       ['examples/VizWiz_test_00005604.jpg', 'What drink is this?'],
                       ['examples/VizWiz_test_00006246.jpg', 'Can you please tell me what kind of tea this is?'],
                       ['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'],
                       ['examples/VizWiz_train_00017146.jpg', "Can you tell me what's on this envelope please?"],
                       ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
             ).launch()