merve HF staff committed on
Commit
b802c2a
•
1 Parent(s): fad59d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ from PIL import Image
4
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
5
+ import spaces
6
+
7
@spaces.GPU
def infer_diagram(image, question):
    """Answer a question about a diagram using the AI2D PaliGemma checkpoint.

    The model and processor are loaded from the Hub on every call.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        and any leading newlines removed.
    """
    checkpoint = "google/paligemma-3b-ft-ai2d-448"
    model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    model = model.to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(checkpoint)

    batch = processor(images=image, text=question, return_tensors="pt")
    batch = batch.to("cuda")

    generated = model.generate(**batch, max_new_tokens=100)
    decoded = processor.decode(generated[0], skip_special_tokens=True)
    # Presumably the decoded text echoes the prompt first; cut that many
    # characters, then the newline(s) in front of the answer.
    answer = decoded[len(question):]
    return answer.lstrip("\n")
16
+
17
@spaces.GPU
def infer_ocrvqa(image, question):
    """Answer a question about text in an image using the OCR-VQA PaliGemma checkpoint.

    The model and processor are loaded from the Hub on every call.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        and any leading newlines removed.
    """
    # The checkpoint is a PaliGemma fine-tune, so it must be loaded with the
    # PaliGemma classes this file imports. Model and processor share one repo
    # id so they cannot drift apart.
    checkpoint = "google/paligemma-3b-ft-ocrvqa-896"
    model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint).to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(checkpoint)

    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")

    predictions = model.generate(**inputs, max_new_tokens=100)
    decoded = processor.decode(predictions[0], skip_special_tokens=True)
    # Presumably the decoded text echoes the prompt first; cut that many
    # characters, then the newline(s) in front of the answer.
    return decoded[len(question):].lstrip("\n")
26
+
27
@spaces.GPU
def infer_infographics(image, question):
    """Answer a question about an infographic using the InfoVQA PaliGemma checkpoint.

    The model and processor are loaded from the Hub on every call.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        and any leading newlines removed.
    """
    checkpoint = "google/paligemma-3b-ft-infovqa-896"
    model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    model = model.to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(checkpoint)

    batch = processor(images=image, text=question, return_tensors="pt")
    batch = batch.to("cuda")

    generated = model.generate(**batch, max_new_tokens=100)
    decoded = processor.decode(generated[0], skip_special_tokens=True)
    # Presumably the decoded text echoes the prompt first; cut that many
    # characters, then the newline(s) in front of the answer.
    answer = decoded[len(question):]
    return answer.lstrip("\n")
36
@spaces.GPU
def infer_doc(image, question):
    """Answer a question about a document using the DocVQA PaliGemma checkpoint.

    The model and processor are loaded from the Hub on every call.

    Args:
        image: Image forwarded to the processor as ``images``.
        question: Prompt text forwarded to the processor as ``text``.

    Returns:
        The decoded generation with its first ``len(question)`` characters
        and any leading newlines removed.
    """
    checkpoint = "google/paligemma-3b-ft-docvqa-896"
    model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    model = model.to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(checkpoint)

    batch = processor(images=image, text=question, return_tensors="pt")
    batch = batch.to("cuda")

    generated = model.generate(**batch, max_new_tokens=100)
    decoded = processor.decode(generated[0], skip_special_tokens=True)
    # Presumably the decoded text echoes the prompt first; cut that many
    # characters, then the newline(s) in front of the answer.
    answer = decoded[len(question):]
    return answer.lstrip("\n")
44
+
45
# Custom CSS handed to gr.Blocks; styles any element with the id "mkd".
css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

# UI: one tab per fine-tuned checkpoint. In every tab the Submit button and
# the example row are wired to the same inference function, so a cached or
# clicked example is answered by the model that tab advertises.
# NOTE(review): Row/Column nesting is reconstructed from statement order --
# confirm it matches the intended layout.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>PaliGemma Fine-tuned on Documents 📄</center></h1>")
    gr.HTML("<h3><center>This Space is built for you to compare different PaliGemma models fine-tuned on document tasks. ⚑</center></h3>")
    gr.HTML("<h3><center>Each tab in this app demonstrates PaliGemma models fine-tuned on document question answering, infographics question answering, diagram understanding, and reading comprehension from images. 📄📕📊</center></h3>")
    gr.HTML("<h3><center>Models are downloaded on the go, so first inference in each tab might take time if it's not already downloaded.</center></h3>")

    with gr.Tab(label="Visual Question Answering over Documents"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Document")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Answer")
        gr.Examples(
            [["assets/docvqa_example.png", "How many items are sold?"]],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_doc,
            label='Click on any Examples below to get Document Question Answering results quickly 👇',
        )

        submit_btn.click(infer_doc, [input_img, question], [output])

    with gr.Tab(label="Visual Question Answering over Infographics"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Image")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Answer")
        gr.Examples(
            [["assets/infographics_example (1).jpeg", "What is this infographic about?"]],
            inputs=[input_img, question],
            outputs=[output],
            # Must name the handler defined in this file; an undefined name
            # here raises NameError while the UI is being built.
            fn=infer_infographics,
            label='Click on any Examples below to get Infographics QA results quickly 👇',
        )

        submit_btn.click(infer_infographics, [input_img, question], [output])

    with gr.Tab(label="Reading from Images"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Document")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Infer")
        submit_btn.click(infer_ocrvqa, [input_img, question], [output])
        gr.Examples(
            [["assets/ocrvqa.jpg", "Who is the author of this book?"]],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_ocrvqa,
            label='Click on any Examples below to get UI question answering results quickly 👇',
        )

    with gr.Tab(label="Diagram Understanding"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Diagram")
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Infer")
        submit_btn.click(infer_diagram, [input_img, question], [output])
        gr.Examples(
            [["assets/diagram.png", "What is the diagram showing?"]],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_diagram,
            label='Click on any Examples below to get UI question answering results quickly 👇',
        )

demo.launch(debug=True)