Ahsen Khaliq committed
Commit e67e20f
1 Parent(s): 2ef97f1

Update app.py

Files changed (1)
  1. app.py +30 -8
app.py CHANGED
@@ -30,16 +30,38 @@ model = blip_decoder(pretrained=model_url, image_size=384, vit='base')
 model.eval()
 model = model.to(device)
 
+
+from models.blip_vqa import blip_vqa
+
+image_size_vq = 480
+transform_vq = transforms.Compose([
+    transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+    ])
+
+model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
 
-def inference(raw_image):
-    image = transform(raw_image).unsqueeze(0).to(device)
-    with torch.no_grad():
-        caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
-    print('caption: '+caption[0])
+model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
+model_vq.eval()
+model_vq = model_vq.to(device)
+
+
+
+def inference(raw_image, model_n, question):
+    if model_n == 'Image Captioning':
+        image = transform(raw_image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+        return 'caption: '+caption[0]
 
-    return 'caption: '+caption[0]
+    else:
+        image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            answer = model_vq(image_vq, question, train=False, inference='generate')
+        return 'answer: '+answer[0]
 
-inputs = gr.inputs.Image(type='pil')
+inputs = [gr.inputs.Image(type='pil'), gr.inputs.Radio(choices=['Image Captioning', "Visual Question Answering"], type="value", default="Image Captioning", label="Model"), "textbox"]
 outputs = gr.outputs.Textbox(label="Output")
 
 title = "BLIP"
@@ -49,4 +71,4 @@ description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training f
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"
 
 
-gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starry.jpg']]).launch(enable_queue=True, cache_examples=True)
+gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starry.jpg', "Image Captioning", ""]]).launch(enable_queue=True, cache_examples=True)
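
For reference, a minimal sketch (not part of the commit) of how the updated two-mode inference function could be exercised outside the Gradio UI, assuming it runs in the same app.py context where transform, transform_vq, model, model_vq, device, and inference are already defined; the example question string is arbitrary:

# Hypothetical smoke test, assuming the definitions above (transform, transform_vq,
# model, model_vq, device, inference) are in scope; 'starry.jpg' is the demo's example image.
from PIL import Image

raw_image = Image.open('starry.jpg').convert('RGB')

# Captioning path: the radio value selects the branch; the question is ignored here.
print(inference(raw_image, 'Image Captioning', ''))

# VQA path: the free-text question is forwarded to the VQA model (example question is arbitrary).
print(inference(raw_image, 'Visual Question Answering', 'What is in the picture?'))

The radio value only routes between the captioning decoder and the VQA model; the textbox input is consulted on the VQA path alone.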