Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
import re
|
3 |
+
from PIL import Image
|
4 |
+
import requests
|
5 |
+
from nougat.dataset.rasterize import rasterize_paper
|
6 |
+
|
7 |
+
from transformers import NougatProcessor, VisionEncoderDecoderModel
|
8 |
+
import torch
|
9 |
+
|
10 |
+
processor = NougatProcessor.from_pretrained("nielsr/nougat")
|
11 |
+
model = VisionEncoderDecoderModel.from_pretrained("nielsr/nougat")
|
12 |
+
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
model.to(device)
|
15 |
+
|
16 |
+
|
17 |
+
def get_pdf(pdf_link):
|
18 |
+
unique_filename = f"{os.getcwd()}/downloaded_paper_{uuid.uuid4().hex}.pdf"
|
19 |
+
|
20 |
+
response = requests.get(pdf_link)
|
21 |
+
|
22 |
+
if response.status_code == 200:
|
23 |
+
with open(unique_filename, 'wb') as pdf_file:
|
24 |
+
pdf_file.write(response.content)
|
25 |
+
print("PDF downloaded successfully.")
|
26 |
+
else:
|
27 |
+
print("Failed to download the PDF.")
|
28 |
+
return unique_filename
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def predict(image):
|
33 |
+
# prepare PDF image for the model
|
34 |
+
image = Image.open(image)
|
35 |
+
pixel_values = processor(image, return_tensors="pt").pixel_values
|
36 |
+
|
37 |
+
# generate transcription (here we only generate 30 tokens)
|
38 |
+
outputs = model.generate(
|
39 |
+
pixel_values.to(device),
|
40 |
+
min_length=1,
|
41 |
+
max_new_tokens=30,
|
42 |
+
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
43 |
+
)
|
44 |
+
|
45 |
+
sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
46 |
+
sequence = processor.post_process_generation(sequence, fix_markdown=False)
|
47 |
+
return sequence
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
def inference(pdf_file, pdf_link):
|
52 |
+
if pdf_file is None:
|
53 |
+
if pdf_link == '':
|
54 |
+
print("No file is uploaded and No link is provided")
|
55 |
+
return "No data provided. Upload a pdf file or provide a pdf link and try again!"
|
56 |
+
else:
|
57 |
+
file_name = get_pdf(pdf_link)
|
58 |
+
else:
|
59 |
+
file_name = pdf_file.name
|
60 |
+
pdf_name = pdf_file.name.split('/')[-1].split('.')[0]
|
61 |
+
|
62 |
+
images = rasterize_paper(file_name, return_pil=True)
|
63 |
+
sequence = ""
|
64 |
+
# infer for every page and concat
|
65 |
+
for image in images:
|
66 |
+
sequence += predict(image)
|
67 |
+
|
68 |
+
|
69 |
+
content = sequence.replace(r'\(', '$').replace(r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
|
70 |
+
return content
|
71 |
+
|
72 |
+
import gradio as gr
|
73 |
+
import uuid
|
74 |
+
import os
|
75 |
+
import requests
|
76 |
+
import re
|
77 |
+
|
78 |
+
css = """
|
79 |
+
#mkd {
|
80 |
+
height: 500px;
|
81 |
+
overflow: auto;
|
82 |
+
border: 1px solid #ccc;
|
83 |
+
}
|
84 |
+
"""
|
85 |
+
|
86 |
+
with gr.Blocks(css=css) as demo:
|
87 |
+
gr.HTML("<h1><center>Nougat: Neural Optical Understanding for Academic Documents 🍫<center><h1>")
|
88 |
+
gr.HTML("<h3><center>Lukas Blecher et al. <a href='https://arxiv.org/pdf/2308.13418.pdf' target='_blank'>Paper</a>, <a href='https://facebookresearch.github.io/nougat/'>Project</a><center></h3>")
|
89 |
+
gr.HTML("<h3><center>This demo is based on transformers implementation of Nougat 🤗<center><h3>")
|
90 |
+
|
91 |
+
|
92 |
+
with gr.Row():
|
93 |
+
mkd = gr.Markdown('<h4><center>Upload a PDF</center></h4>',scale=1)
|
94 |
+
mkd = gr.Markdown('<h4><center><i>OR</i></center></h4>',scale=1)
|
95 |
+
mkd = gr.Markdown('<h4><center>Provide a PDF link</center></h4>',scale=1)
|
96 |
+
|
97 |
+
|
98 |
+
with gr.Row():
|
99 |
+
mkd = gr.Markdown("Upload a PDF",scale=1)
|
100 |
+
mkd = gr.Markdown('OR',scale=1)
|
101 |
+
mkd = gr.Markdown('Provide a PDF link',scale=1)
|
102 |
+
|
103 |
+
with gr.Row(equal_height=True):
|
104 |
+
pdf_file = gr.File(label='PDF 📑', file_count='single', scale=1)
|
105 |
+
pdf_link = gr.Textbox(placeholder='Enter an arxiv link here', label='Link to Paper🔗', scale=1)
|
106 |
+
|
107 |
+
with gr.Row():
|
108 |
+
btn = gr.Button('Run Nougat 🍫')
|
109 |
+
clr = gr.Button('Clear 🧼')
|
110 |
+
|
111 |
+
output_headline = gr.Markdown("PDF converted to markup language through Nougat-OCR👇")
|
112 |
+
parsed_output = gr.Markdown(elem_id='mkd', value='OCR Output 📝')
|
113 |
+
|
114 |
+
btn.click(inference, [pdf_file, pdf_link], parsed_output )
|
115 |
+
clr.click(lambda : (gr.update(value=None),
|
116 |
+
gr.update(value=None),
|
117 |
+
gr.update(value=None)),
|
118 |
+
[],
|
119 |
+
[pdf_file, pdf_link, parsed_output]
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
demo.queue()
|
125 |
+
demo.launch(debug=True)
|