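# Gradio demo app for BLIP: image captioning and visual question answering
# using pretrained BLIP checkpoints (Salesforce Research).
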
from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
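# The demo runs on the GPU when CUDA is available and falls back to the CPU otherwise.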

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import gradio as gr

from models.blip import blip_decoder
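# Image captioning: preprocessing transform and BLIP decoder (ViT-L, 384x384 inputs).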

image_size = 384
transform = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
    
model = blip_decoder(pretrained=model_url, image_size=384, vit='large')
model.eval()
model = model.to(device)


from models.blip_vqa import blip_vqa
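# Visual question answering: preprocessing transform and BLIP VQA model (ViT-B, 480x480 inputs).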

image_size_vq = 480
transform_vq = transforms.Compose([
    transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
    
model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
model_vq.eval()
model_vq = model_vq.to(device)



def inference(raw_image, model_n, question, strategy):
    """Run either image captioning or visual question answering on a PIL image."""
    if model_n == 'Image Captioning':
        image = transform(raw_image).unsqueeze(0).to(device)
        with torch.no_grad():
            # Decode with beam search (deterministic) or nucleus sampling (stochastic).
            if strategy == "Beam search":
                caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
            else:
                caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
        return 'caption: ' + caption[0]
    else:
        image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
        with torch.no_grad():
            answer = model_vq(image_vq, question, train=False, inference='generate')
        return 'answer: ' + answer[0]
    
inputs = [
    gr.Image(type='pil', interactive=False),
    gr.Radio(choices=['Image Captioning', "Visual Question Answering"],
             value="Image Captioning",
             label="Task"),
    gr.Textbox(lines=2, label="Question"),
    gr.Radio(choices=['Beam search', 'Nucleus sampling'],
             value="Nucleus sampling",
             label="Caption Decoding Strategy"),
]
outputs = gr.Textbox(label="Output")

title = "BLIP"

description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). We have now disable image uploading as of March 23. 2023. Click one of the examples to load them. Read more at the links below."

article = """<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"""

gr.Interface(inference, inputs, outputs, title=title, description=description, article=article,
             examples=[['starrynight.jpeg', "Image Captioning", "None", "Nucleus sampling"]]).queue().launch()
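
# For reference, inference() can also be called directly without the UI, e.g. from a
# Python shell (illustrative only; the VQA question below is a made-up example):
#
#   img = Image.open('starrynight.jpeg').convert('RGB')
#   print(inference(img, 'Image Captioning', 'None', 'Nucleus sampling'))
#   print(inference(img, 'Visual Question Answering', 'what is in the picture?', ''))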