# SmallCapDemo / app.py
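"""Gradio demo for SmallCap: retrieval-augmented image captioning.

The app embeds the input image with CLIP, retrieves similar captions from a FAISS
index of COCO captions, and feeds them as a prompt to a SmallCap model with an OPT
decoder, which generates the final caption.
"""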
import json
import requests
import torch
import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import ViTFeatureExtractor, AutoTokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM
from transformers.models.auto.configuration_auto import AutoConfig

from src.vision_encoder_decoder import SmallCap, SmallCapConfig
from src.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel
from src.opt import ThisOPTConfig, ThisOPTForCausalLM
from src.utils import prep_strings, postprocess_preds
from src.retrieve_caps import *

# clip, faiss and numpy are used below; import them explicitly in case the
# wildcard import from src.retrieve_caps does not provide them
import clip
import faiss
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the CLIP feature extractor used by the SmallCap encoder
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# load and configure the tokenizer for the OPT decoder; '!' serves as padding and
# '.' as end-of-sequence, so generation stops after the first full stop
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = '!'
tokenizer.eos_token = '.'
# register SmallCap's custom decoder and model classes with the Auto* factories,
# then load the pretrained checkpoint from the Hugging Face Hub

# GPT-2 decoder variant (unused in this demo):
# AutoConfig.register("this_gpt2", ThisGPT2Config)
# AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel)
# AutoConfig.register("smallcap", SmallCapConfig)
# AutoModel.register(SmallCapConfig, SmallCap)
# model = AutoModel.from_pretrained("Yova/SmallCap7M")

# OPT decoder variant used here:
AutoConfig.register("this_opt", ThisOPTConfig)
AutoModel.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoModelForCausalLM.register(ThisOPTConfig, ThisOPTForCausalLM)
AutoConfig.register("smallcap", SmallCapConfig)
AutoModel.register(SmallCapConfig, SmallCap)
model = AutoModel.from_pretrained("Yova/SmallCapOPT7M")
model = model.to(device)

# prompt template; prep_strings fills it with the retrieved captions at inference time
template = open('src/template.txt').read().strip() + ' '
# load the retrieval components: the precomputed caption datastore, the CLIP
# retrieval encoder, and the FAISS index over the datastore embeddings
captions = json.load(open('coco_index_captions.json'))
retrieval_model, feature_extractor_retrieval = clip.load("RN50x64", device=device)
retrieval_index = faiss.read_index('coco_index')
# optional GPU index (requires faiss-gpu):
#res = faiss.StandardGpuResources()
#retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index)
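
# rough sketch (an assumption, not the script used to build 'coco_index'): the index
# could be an inner-product FAISS index over L2-normalised CLIP embeddings of the
# datastore captions, e.g.
#   feats = clip_embed(captions).astype(np.float32)   # hypothetical embedding helper
#   faiss.normalize_L2(feats)
#   index = faiss.IndexFlatIP(feats.shape[1]); index.add(feats)
#   faiss.write_index(index, 'coco_index')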
# download human-readable ImageNet labels; these appear to be leftover from the
# Gradio image-classification template and are not used anywhere below
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def retrieve_caps(image_embedding, index, k=4):
    # L2-normalise the query so inner-product search behaves like cosine similarity,
    # then return the indices of the k nearest captions in the datastore
    xq = image_embedding.astype(np.float32)
    faiss.normalize_L2(xq)
    D, I = index.search(xq, k)
    return I
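
# index.search returns (distances, indices); only the (n_queries, k) index array
# is kept, so callers take row 0 for a single query embedding, e.g.
#   nns = retrieve_caps(image_embedding, retrieval_index)[0]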
def classify_image(image):
    inp = transforms.ToTensor()(image)  # unused; leftover, like the commented-out classifier lines below
    # embed the image with the CLIP retrieval encoder and look up the four
    # most similar captions in the FAISS datastore
    pixel_values_retrieval = feature_extractor_retrieval(image).to(device)
    with torch.no_grad():
        image_embedding = retrieval_model.encode_image(pixel_values_retrieval.unsqueeze(0)).cpu().numpy()
    nns = retrieve_caps(image_embedding, retrieval_index)[0]
    caps = [captions[i] for i in nns][:4]
    # prepare the prompt: the retrieved captions are filled into the template
    decoder_input_ids = prep_strings('', tokenizer, template=template, retrieved_caps=caps, k=4, is_test=True)
    # generate the caption with beam search, conditioned on the image and the prompt
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        pred = model.generate(pixel_values.to(device),
                              decoder_input_ids=torch.tensor([decoder_input_ids]).to(device),
                              max_new_tokens=25, no_repeat_ngram_size=0, length_penalty=0,
                              min_length=1, num_beams=3, eos_token_id=tokenizer.eos_token_id)
    #inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    #prediction = inception_net.predict(inp).flatten()
    retrieved_caps = "Retrieved captions: \n{}\n{}\n{}\n{}".format(*caps)
    #return retrieved_caps + "\n\n\n Generated caption:\n" + str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer))
    return str(postprocess_preds(tokenizer.decode(pred[0]), tokenizer)) + "\n\n\n" + retrieved_caps
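
# a minimal sketch of running the pipeline without the Gradio UI
# (assumes a local file 'example.jpg'; not part of the original demo):
#   img = Image.open('example.jpg').convert('RGB')
#   print(classify_image(img))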
# Gradio UI: an image input and a text box showing the generated and retrieved captions
image = gr.Image(type="pil")
textbox = gr.Textbox(placeholder="Generated caption and retrieved captions...", lines=4)
title = "SmallCap Demo"

gr.Interface(
    fn=classify_image, inputs=image, outputs=textbox, title=title
).launch()