Spaces:
Build error
Build error
File size: 3,297 Bytes
0ae6b5c ea784ed 0ae6b5c 4214d1d 0ae6b5c e67e20f 0ae6b5c e67e20f 2618f02 c6a9352 e67e20f 369f023 2618f02 f16c093 e67e20f 0ae6b5c e67e20f c08ba31 e67e20f 0ae6b5c d2f0b33 0ae6b5c 2ef97f1 0ae6b5c 2a8a686 0ae6b5c 2a8a686 0ae6b5c 01775c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import gradio as gr
from models.blip import blip_decoder
# ---------------------------------------------------------------------------
# Image captioning model
# ---------------------------------------------------------------------------

# Side length (in pixels) of the square image fed to the captioning model.
image_size = 384

# Captioning preprocessing: bicubic resize to a square, conversion to a
# tensor, then per-channel mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize(
        size=(image_size, image_size),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# Location of the pretrained captioning checkpoint passed to blip_decoder.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'

# Build the captioning model once at start-up, switch it to evaluation mode
# and place it on the selected device.
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='large')
model = model.eval().to(device)
from models.blip_vqa import blip_vqa

# ---------------------------------------------------------------------------
# Visual question answering (VQA) model
# ---------------------------------------------------------------------------

# Side length (in pixels) of the square image fed to the VQA model.
image_size_vq = 480

# VQA preprocessing: same steps as the captioning pipeline, but resized to
# the VQA input size.
transform_vq = transforms.Compose([
    transforms.Resize(
        size=(image_size_vq, image_size_vq),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# Location of the pretrained VQA checkpoint passed to blip_vqa.
model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'

# Build the VQA model once at start-up, switch it to evaluation mode and
# place it on the selected device.
model_vq = blip_vqa(pretrained=model_url_vq, image_size=image_size_vq, vit='base')
model_vq = model_vq.eval().to(device)
def inference(raw_image, model_n, question, strategy):
    """Run the selected BLIP task on one image and return a labelled string.

    Args:
        raw_image: PIL image supplied by the Gradio image component.
        model_n: Task name. ``'Image Captioning'`` selects the captioning
            model; any other value selects visual question answering.
        question: Question text. Only used for visual question answering.
        strategy: Caption decoding strategy. ``"Beam search"`` selects beam
            search; any other value selects nucleus sampling. Only used for
            captioning.

    Returns:
        ``'caption: <text>'`` for captioning, ``'answer: <text>'`` for
        visual question answering.

    Raises:
        ValueError: If ``raw_image`` is ``None`` (no image was selected).
    """
    # The image component can deliver None when nothing has been loaded;
    # fail with a clear message instead of an opaque error in the transform.
    if raw_image is None:
        raise ValueError("No image provided: select an example image before submitting.")

    # The Normalize step uses three-channel statistics, so RGBA, grayscale
    # or palette images must be converted to RGB before preprocessing.
    rgb_image = raw_image.convert('RGB')

    if model_n == 'Image Captioning':
        image = transform(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            if strategy == "Beam search":
                caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
            else:
                caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
        return 'caption: ' + caption[0]

    image_vq = transform_vq(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        answer = model_vq(image_vq, question, train=False, inference='generate')
    return 'answer: ' + answer[0]
# Input components, in the positional order of inference()'s parameters:
# image, task, question, caption decoding strategy.
# The top-level component classes (gr.Radio / gr.Textbox with ``value=``)
# are used throughout instead of the deprecated gr.inputs / gr.outputs
# namespaces, matching the gr.Image component already used here.
inputs = [
    # interactive=False disables uploads; images are loaded via the examples.
    gr.Image(type='pil', interactive=False),
    gr.Radio(
        choices=['Image Captioning', "Visual Question Answering"],
        type="value",
        value="Image Captioning",
        label="Task",
    ),
    gr.Textbox(lines=2, label="Question"),
    gr.Radio(
        choices=['Beam search', 'Nucleus sampling'],
        type="value",
        value="Nucleus sampling",
        label="Caption Decoding Strategy",
    ),
]

# Single text output showing the caption or the answer.
outputs = gr.Textbox(label="Output")

title = "BLIP"
description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). We have now disable image uploading as of March 23. 2023. Click one of the examples to load them. Read more at the links below."
article = """<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"""

# Build the interface and start serving. Request queuing is enabled through
# .queue(), which replaces the deprecated launch(enable_queue=True) argument.
gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=[['starrynight.jpeg', "Image Captioning", "None", "Nucleus sampling"]],
).queue().launch()