Upload 3 files

Files changed:
  app.py            +171  -148
  packages.txt        +2    -3
  requirements.txt    +6    -7

app.py
CHANGED
@@ -1,148 +1,171 @@
[Removed version: 148 lines. Only fragments of the deleted side survived extraction (the same opening imports and logging setup as the new file); the rest is not reproduced here.]

New version:
import gradio as gr
from typing import Dict
import logging
import tempfile
import io
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from pdf2image import convert_from_bytes
from PIL import Image
import pytesseract
import docx  # python-docx; required by save_as_docx below but absent from the committed file
import docx2txt
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class AdvancedDocProcessor:
    def __init__(self):
        # Initialize BART model for text cleaning and summarization
        self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

        # Initialize T5 model for text generation tasks
        self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")

        # Initialize pipeline for named entity recognition
        self.ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")

    def extract_text(self, file_content: bytes, file_type: str) -> str:
        """Extract text from various file types."""
        try:
            if file_type == "application/pdf":
                return self.extract_text_from_pdf(file_content)
            elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return self.extract_text_from_docx(file_content)
            elif file_type == "text/plain":
                return file_content.decode('utf-8')
            else:
                raise ValueError(f"Unsupported file type: {file_type}")
        except Exception as e:
            logger.error(f"Error extracting text: {str(e)}")
            return ""

    def extract_text_from_pdf(self, pdf_content: bytes) -> str:
        """Extract text from PDF using OCR."""
        images = convert_from_bytes(pdf_content)
        text = ""
        for image in images:
            text += pytesseract.image_to_string(image)
        return text

    def extract_text_from_docx(self, docx_content: bytes) -> str:
        """Extract text from a DOCX file."""
        return docx2txt.process(io.BytesIO(docx_content))

    def clean_and_summarize_text(self, text: str) -> str:
        """Clean and summarize the text using BART."""
        inputs = self.bart_tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
        summary_ids = self.bart_model.generate(inputs["input_ids"], num_beams=4, max_length=150, early_stopping=True)
        return self.bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_with_t5(self, text: str, prompt: str) -> str:
        """Process the text with T5 based on the given prompt."""
        input_text = f"{prompt} {text}"
        input_ids = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).input_ids
        # Note: temperature has no effect here unless do_sample=True is also passed
        outputs = self.t5_model.generate(input_ids, max_length=150, num_return_sequences=1, temperature=0.7)
        return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def extract_entities(self, text: str) -> str:
        """Extract named entities from the text."""
        # Note: input is passed unchunked and may exceed the model's 512-token limit on long documents
        entities = self.ner_pipeline(text)
        unique_entities = set((ent['word'], ent['entity']) for ent in entities)
        return "\n".join([f"{word} ({entity})" for word, entity in unique_entities])

    def process_document(self, file_content: bytes, file_type: str, prompt: str) -> Dict[str, str]:
        raw_text = self.extract_text(file_content, file_type)
        cleaned_text = self.clean_and_summarize_text(raw_text)
        processed_text = self.process_with_t5(cleaned_text, prompt)
        entities = self.extract_entities(raw_text)

        return {
            "original": raw_text,
            "cleaned": cleaned_text,
            "processed": processed_text,
            "entities": entities
        }

def infer_file_type(file_content: bytes) -> str:
    """Infer the file type from the byte content."""
    if file_content.startswith(b'%PDF'):
        return "application/pdf"
    elif file_content.startswith(b'PK\x03\x04'):
        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    else:
        return "text/plain"

def create_gradio_interface():
    processor = AdvancedDocProcessor()

    def process_and_display(file, prompt, output_format):
        file_content = file  # with type="binary", gr.File delivers the raw bytes
        file_type = infer_file_type(file_content)
        results = processor.process_document(file_content, file_type, prompt)

        if output_format == "txt":
            output_path = save_as_txt(results)
        elif output_format == "docx":
            output_path = save_as_docx(results)
        else:  # pdf
            output_path = save_as_pdf(results)

        return (f"Original Text (first 500 chars):\n{results['original'][:500]}...\n\n"
                f"Cleaned and Summarized Text:\n{results['cleaned']}\n\n"
                f"Processed Text:\n{results['processed']}\n\n"
                f"Extracted Entities:\n{results['entities']}"), output_path

    iface = gr.Interface(
        fn=process_and_display,
        inputs=[
            gr.File(label="Upload Document (PDF, DOCX, or TXT)", type="binary"),
            gr.Textbox(label="Enter your prompt for processing", lines=3),
            gr.Radio(["txt", "docx", "pdf"], label="Output Format", value="txt")
        ],
        outputs=[
            gr.Textbox(label="Processing Results", lines=30),
            gr.File(label="Download Processed Document")
        ],
        title="Advanced Document Processing Tool",
        description="Upload a document (PDF, DOCX, or TXT) and enter a prompt to process and analyze the text using state-of-the-art NLP models.",
    )

    return iface

def save_as_txt(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as temp_file:
        for key, value in results.items():
            temp_file.write(f"{key.upper()}:\n{value}\n\n")
        return temp_file.name

def save_as_docx(results: Dict[str, str]) -> str:
    doc = docx.Document()
    for key, value in results.items():
        doc.add_heading(key.capitalize(), level=1)
        doc.add_paragraph(value)

    with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp:
        doc.save(tmp.name)
        return tmp.name

def save_as_pdf(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
        doc = SimpleDocTemplate(tmp.name, pagesize=letter)
        styles = getSampleStyleSheet()
        story = []

        for key, value in results.items():
            story.append(Paragraph(key.capitalize(), styles['Heading1']))
            story.append(Paragraph(value, styles['BodyText']))
            story.append(Spacer(1, 12))

        doc.build(story)
        return tmp.name

# Launch the Gradio app
if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()
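
For a quick sanity check outside the Gradio UI, a minimal driver along these lines should work (a hypothetical sketch, not part of this commit; sample.pdf and the prompt are placeholders, and the first run downloads several gigabytes of model weights):

# Hypothetical smoke test for the committed pipeline (not in the repo).
from app import AdvancedDocProcessor, infer_file_type

processor = AdvancedDocProcessor()           # loads BART, FLAN-T5 and the NER model
with open("sample.pdf", "rb") as f:          # placeholder input file
    content = f.read()

results = processor.process_document(content, infer_file_type(content),
                                     "Summarize the key points:")
print(results["cleaned"])                    # BART summary
print(results["entities"])                   # one "word (TAG)" line per entity
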
packages.txt
CHANGED
@@ -1,3 +1,2 @@
 tesseract-ocr
 libtesseract-dev
-libleptonica-dev
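
One caveat, an assumption on my part rather than part of this commit: pdf2image, which the new app.py uses to rasterize PDFs for OCR, shells out to Poppler's command-line tools at runtime, so the package list would likely also need:

poppler-utils
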
requirements.txt
CHANGED
@@ -1,7 +1,6 @@
 gradio
 pytesseract
 PyMuPDF
 Pillow
 torch
 transformers
-tqdm
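
A second caveat, again an assumption rather than part of this commit: the new app.py also imports pdf2image, docx2txt, and reportlab, and its save_as_docx helper needs python-docx (and possibly sentencepiece for the T5 tokenizer), none of which appear above. A fuller list would plausibly be:

gradio
pytesseract
PyMuPDF
Pillow
torch
transformers
pdf2image
docx2txt
reportlab
python-docx

(PyMuPDF, carried over from the old list, is no longer imported anywhere in the new app.py and could likely be dropped.)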