File size: 5,481 Bytes
b395c00
 
 
cbadb1a
b395c00
 
 
 
 
 
 
 
de2dda2
 
b395c00
 
 
 
de2dda2
 
b395c00
 
 
 
 
 
 
c2c8861
de2dda2
 
c2c8861
b395c00
 
 
 
 
 
 
 
 
 
eda8f6b
36f181e
9acd641
c2c8861
 
 
 
 
 
 
 
 
 
36f181e
 
dc6e60c
61922b8
8d1868a
52c29b8
b395c00
4056078
b395c00
 
 
 
 
1d036b5
a2b1833
fe6ca74
b395c00
 
 
 
 
 
 
 
 
 
 
0647c43
b395c00
 
 
 
 
 
de2dda2
b395c00
c2c8861
b395c00
de2dda2
b395c00
c2c8861
b395c00
c2c8861
b395c00
 
 
c567393
b395c00
 
 
 
 
 
 
 
 
 
 
 
 
a298ea6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch
import math

# Fetch the example images used by the demo's "examples" list.
for _url, _filename in (
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png'),
    ('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg'),
):
    torch.hub.download_url_to_file(_url, _filename)
del _url, _filename

# Load the base-size checkpoint of each model family.
# NOTE: the large variants ("microsoft/git-large-vqav2",
# "Salesforce/blip-vqa-capfilt-large") are intentionally not loaded.
git_processor_base, git_model_base = (
    AutoProcessor.from_pretrained("microsoft/git-base-vqav2"),
    AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2"),
)
blip_processor_base, blip_model_base = (
    AutoProcessor.from_pretrained("Salesforce/blip-vqa-base"),
    BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base"),
)
vilt_processor, vilt_model = (
    AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa"),
    ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa"),
)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Only the GIT model is moved; BLIP and ViLT are left where
# from_pretrained placed them.
git_model_base.to(device)

def generate_answer_git(processor, model, image, question, max_length=50):
    """Answer a question about an image with a GIT model fine-tuned for VQA.

    GIT handles VQA as text generation: the question, prefixed with the
    tokenizer's CLS token, is fed as a prompt and the model continues it
    with the answer.

    Args:
        processor: GIT processor used to preprocess the image, tokenize the
            question (``processor.tokenizer.cls_token_id`` is required) and
            decode generated ids via ``batch_decode``.
        model: GIT causal-LM model such as ``microsoft/git-base-vqav2``.
        image: Image accepted by ``processor`` (the demo passes a PIL image).
        question: Question text.
        max_length: Maximum total length (prompt + answer) forwarded to
            ``model.generate``.

    Returns:
        The generated answer as a string, with the echoed question removed.
    """
    # Inputs must live on the same device as the model, which may be CUDA.
    device = model.device

    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: [CLS] + question tokens and no trailing [SEP], so the
    # model continues the prompt with the answer.
    question_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = torch.tensor([[processor.tokenizer.cls_token_id] + question_ids], device=device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=max_length)

    # generate() returns the prompt followed by the new tokens; decode only
    # what was produced after the prompt.
    answer_ids = generated_ids[:, input_ids.shape[1]:]
    return processor.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()


def generate_answer_blip(processor, model, image, question, max_length=50):
    """Answer a question about an image with a BLIP VQA model.

    Args:
        processor: BLIP processor that encodes the image and question together
            and decodes generated ids via ``batch_decode``.
        model: BLIP question-answering model such as ``Salesforce/blip-vqa-base``.
        image: Image accepted by ``processor`` (the demo passes a PIL image).
        question: Question text.
        max_length: Maximum generated length forwarded to ``model.generate``.

    Returns:
        The decoded answer as a single string (the only item of the decoded
        batch), so it can be shown directly in a text box.
    """
    # prepare image + question on the same device as the model
    inputs = processor(images=image, text=question, return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    generated_ids = model.generate(**inputs, max_length=max_length)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


def generate_answer_vilt(processor, model, image, question):
    """Answer a question about an image with a ViLT VQA model.

    Args:
        processor: ViLT processor that encodes the image and question together.
        model: ViLT question-answering model such as
            ``dandelin/vilt-b32-finetuned-vqa``.
        image: Image accepted by ``processor`` (the demo passes a PIL image).
        question: Question text.

    Returns:
        The label in ``model.config.id2label`` with the highest logit.
    """
    # prepare image + question on the same device as the model
    encoding = processor(images=image, text=question, return_tensors="pt")
    encoding = {name: tensor.to(model.device) for name, tensor in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)
    # The answer is picked by classification: the arg-max logit indexes the
    # model's label table.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def generate_answers(image, question):
    """Run every loaded VQA model on the same image/question pair.

    Args:
        image: Image from the Gradio image input (PIL).
        question: Question text from the Gradio text input.

    Returns:
        A 3-tuple ``(git_base, blip_base, vilt)`` of answers, in the same
        order as the Gradio output text boxes (GIT-base, BLIP-base, ViLT).
    """
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)
    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)
    answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)

    # The large GIT/BLIP checkpoints are not loaded at module level, so they
    # are not queried here.
    return answer_git_base, answer_blip_base, answer_vilt

   
examples = [["cats.jpg", "How many cats are there?"], ["stop_sign.png", "What's behind the stop sign?"], ["astronaut.jpg", "What's the astronaut riding on?"]]
# One text box per model; the order must match the answers returned by
# generate_answers. Top-level components replace the deprecated
# gr.inputs / gr.outputs namespaces (removed in Gradio 4).
outputs = [gr.Textbox(label="Answer generated by GIT-base"), gr.Textbox(label="Answer generated by BLIP-base"), gr.Textbox(label="Answer generated by ViLT")]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio Demo to compare GIT, BLIP and ViLT, 3 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/vilt' target='_blank'>ViLT docs</a></p>"

interface = gr.Interface(fn=generate_answers,
                         inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
# Queue incoming requests (replaces the deprecated enable_queue=True).
interface.queue()
interface.launch()