Scharbhen committed on
Commit
4bc69f5
1 Parent(s): 485c677

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -73
app.py CHANGED
@@ -4,37 +4,16 @@ from PIL import Image
4
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
5
  import spaces
6
 
7
- @spaces.GPU
8
- def infer_diagram(image, question):
9
- model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-ai2d-448").to("cuda")
10
- processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-ai2d-448")
11
-
12
- inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
13
-
14
- predictions = model.generate(**inputs, max_new_tokens=100)
15
- return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
16
 
17
  @spaces.GPU
18
  def infer_ocrvqa(image, question):
19
  model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-ocrvqa-896").to("cuda")
20
  processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-ocrvqa-896")
21
-
22
  inputs = processor(images=image,text=question, return_tensors="pt").to("cuda")
23
-
24
- predictions = model.generate(**inputs, max_new_tokens=200)
25
-
26
  return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
27
 
28
  @spaces.GPU
29
- def infer_infographics(image, question):
30
- model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-infovqa-896").to("cuda")
31
- processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-infovqa-896")
32
-
33
- inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
34
-
35
- predictions = model.generate(**inputs, max_new_tokens=100)
36
- return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
37
- @spaces.GPU
38
  def infer_doc(image, question):
39
  model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-docvqa-896").to("cuda")
40
 
@@ -52,45 +31,22 @@ css = """
52
  """
53
 
54
  with gr.Blocks(css=css) as demo:
55
- gr.HTML("<h1><center>PaliGemma Fine-tuned on Documents 📄<center><h1>")
56
- gr.HTML("<h3><center>This Space is built for you to compare different PaliGemma models fine-tuned on document tasks. ⚡</h3>")
57
- gr.HTML("<h3><center>Each tab in this app demonstrates PaliGemma models fine-tuned on document question answering, infographics question answering, diagram understanding, and reading comprehension from images. 📄📕📊<h3>")
58
- gr.HTML("<h3><center>Models are downloaded on the go, so first inference in each tab might take time if it's not already downloaded.<h3>")
59
 
60
- with gr.Tab(label="Visual Question Answering over Documents"):
61
  with gr.Row():
62
  with gr.Column():
63
  input_img = gr.Image(label="Input Document")
64
  question = gr.Text(label="Question")
65
  submit_btn = gr.Button(value="Submit")
66
  output = gr.Text(label="Answer")
67
- gr.Examples(
68
- [["assets/docvqa_example.png", "How many items are sold?"]],
69
- inputs = [input_img, question],
70
- outputs = [output],
71
- fn=infer_doc,
72
- label='Click on any Examples below to get Document Question Answering results quickly 👇'
73
- )
74
 
75
  submit_btn.click(infer_doc, [input_img, question], [output])
76
 
77
- with gr.Tab(label="Visual Question Answering over Infographics"):
78
- with gr.Row():
79
- with gr.Column():
80
- input_img = gr.Image(label="Input Image")
81
- question = gr.Text(label="Question")
82
- submit_btn = gr.Button(value="Submit")
83
- output = gr.Text(label="Answer")
84
- gr.Examples(
85
- [["assets/infographics_example (1).jpeg", "What is this infographic about?"]],
86
- inputs = [input_img, question],
87
- outputs = [output],
88
- fn=infer_infographics,
89
- label='Click on any Examples below to get Infographics QA results quickly 👇'
90
- )
91
-
92
- submit_btn.click(infer_infographics, [input_img, question], [output])
93
- with gr.Tab(label="Reading from Images"):
94
  with gr.Row():
95
  with gr.Column():
96
  input_img = gr.Image(label="Input Document")
@@ -98,27 +54,5 @@ with gr.Blocks(css=css) as demo:
98
  submit_btn = gr.Button(value="Submit")
99
  output = gr.Text(label="Infer")
100
  submit_btn.click(infer_ocrvqa, [input_img, question], [output])
101
- gr.Examples(
102
- [["assets/ocrvqa.jpg", "Who is the author of this book?"]],
103
- inputs = [input_img, question],
104
- outputs = [output],
105
- fn=infer_doc,
106
- label='Click on any Examples below to get image reading comprehension results quickly 👇'
107
- )
108
- with gr.Tab(label="Diagram Understanding"):
109
- with gr.Row():
110
- with gr.Column():
111
- input_img = gr.Image(label="Input Diagram")
112
- question = gr.Text(label="Question")
113
- submit_btn = gr.Button(value="Submit")
114
- output = gr.Text(label="Infer")
115
- submit_btn.click(infer_diagram, [input_img, question], [output])
116
- gr.Examples(
117
- [["assets/diagram.png", "What is the diagram showing?"]],
118
- inputs = [input_img, question],
119
- outputs = [output],
120
- fn=infer_doc,
121
- label='Click on any Examples below to get diagram understanding results quickly 👇'
122
- )
123
 
124
  demo.launch(debug=True)
 
4
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
5
  import spaces
6
 
 
 
 
 
 
 
 
 
 
7
 
8
  @spaces.GPU
9
  def infer_ocrvqa(image, question):
10
  model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-ocrvqa-896").to("cuda")
11
  processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-ocrvqa-896")
 
12
  inputs = processor(images=image,text=question, return_tensors="pt").to("cuda")
13
+ predictions = model.generate(**inputs, max_new_tokens=100)
 
 
14
  return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
15
 
16
  @spaces.GPU
 
 
 
 
 
 
 
 
 
17
  def infer_doc(image, question):
18
  model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-docvqa-896").to("cuda")
19
 
 
31
  """
32
 
33
  with gr.Blocks(css=css) as demo:
34
+ gr.HTML("<h1><center>PaliGemma для VQA/OCR 📄<center><h1>")
35
+ gr.HTML("<h3><center>Использование модели "как есть" без файнтюнинга на документах. ⚡</h3>")
36
+
 
37
 
38
+ with gr.Tab(label="Ответы на вопросы по документам"):
39
  with gr.Row():
40
  with gr.Column():
41
  input_img = gr.Image(label="Input Document")
42
  question = gr.Text(label="Question")
43
  submit_btn = gr.Button(value="Submit")
44
  output = gr.Text(label="Answer")
 
 
 
 
 
 
 
45
 
46
  submit_btn.click(infer_doc, [input_img, question], [output])
47
 
48
+
49
+ with gr.Tab(label="Чтение текста со сканов"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with gr.Row():
51
  with gr.Column():
52
  input_img = gr.Image(label="Input Document")
 
54
  submit_btn = gr.Button(value="Submit")
55
  output = gr.Text(label="Infer")
56
  submit_btn.click(infer_ocrvqa, [input_img, question], [output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  demo.launch(debug=True)