# ocrtest / app.py
# Uploaded by kopeck in commit 40b0c63 ("Upload 3 files", verified), 7.47 kB.
import gradio as gr
from typing import Dict
import logging
import tempfile
import io
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from pdf2image import convert_from_bytes
from PIL import Image
import pytesseract
import docx2txt
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet
# Configure process-wide logging once at import time; everything below logs
# through the module-named logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class AdvancedDocProcessor:
    """Extract text from uploaded documents and analyze it with NLP models.

    Pipeline: extract raw text (OCR for PDFs, docx2txt for DOCX, decoding for
    plain text) -> summarize with BART -> apply a user prompt with FLAN-T5 ->
    tag named entities with a BERT token-classification model.
    """

    def __init__(self):
        # NOTE: every model is loaded eagerly, so constructing this class is
        # slow and memory-heavy; build one instance and reuse it.
        # BART (bart-large-cnn checkpoint) for summarization.
        self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        # FLAN-T5 for free-form, prompt-driven generation.
        self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
        # Named entity recognition. aggregation_strategy="simple" merges word
        # pieces (e.g. "Wolf" + "##gang") into whole entities with one label,
        # instead of reporting every sub-token separately.
        self.ner_pipeline = pipeline(
            "ner",
            model="dbmdz/bert-large-cased-finetuned-conll03-english",
            aggregation_strategy="simple",
        )

    def extract_text(self, file_content: bytes, file_type: str) -> str:
        """Extract text from a PDF, DOCX or plain-text upload.

        Args:
            file_content: Raw bytes of the uploaded file.
            file_type: MIME type of the content (see ``infer_file_type``).

        Returns:
            The extracted text, or ``""`` if the type is unsupported or
            extraction fails (the error is logged with its traceback).
        """
        try:
            if file_type == "application/pdf":
                return self.extract_text_from_pdf(file_content)
            if file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return self.extract_text_from_docx(file_content)
            if file_type == "text/plain":
                # Strip a UTF-8 BOM and replace undecodable bytes instead of
                # failing the whole upload on the first non-UTF-8 byte.
                return file_content.decode("utf-8-sig", errors="replace")
            raise ValueError(f"Unsupported file type: {file_type}")
        except Exception as e:
            # Deliberate best-effort boundary: the UI should still render, so
            # log the full traceback and fall back to empty text.
            logger.exception(f"Error extracting text: {str(e)}")
            return ""

    def extract_text_from_pdf(self, pdf_content: bytes) -> str:
        """OCR every page of a PDF with Tesseract and concatenate the results.

        Pages are rasterized with pdf2image first, so PDFs that already carry
        a text layer are OCR'd as well.
        """
        pages = convert_from_bytes(pdf_content)
        return "".join(pytesseract.image_to_string(page) for page in pages)

    def extract_text_from_docx(self, docx_content: bytes) -> str:
        """Extract text from DOCX bytes with docx2txt (via an in-memory stream)."""
        return docx2txt.process(io.BytesIO(docx_content))

    def clean_and_summarize_text(self, text: str) -> str:
        """Summarize *text* with BART.

        Input beyond 1024 tokens is truncated, so only the start of a long
        document contributes. Returns ``""`` for blank input instead of
        running the model on nothing.
        """
        if not text.strip():
            return ""
        inputs = self.bart_tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
        summary_ids = self.bart_model.generate(
            **inputs,  # input_ids plus the matching attention_mask
            num_beams=4,
            max_length=150,
            early_stopping=True,
        )
        return self.bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_with_t5(self, text: str, prompt: str) -> str:
        """Apply the user's *prompt* to *text* with FLAN-T5.

        The prompt is prepended to the text and the combined input is
        truncated to 512 tokens, so the prompt itself is never cut off.
        Returns ``""`` when *text* is blank.
        """
        if not text.strip():
            return ""
        input_text = f"{prompt} {text}"
        input_ids = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).input_ids
        # Sampling is not enabled here, and `temperature` only takes effect
        # with do_sample=True, so it is deliberately not passed.
        outputs = self.t5_model.generate(input_ids, max_length=150, num_return_sequences=1)
        return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def extract_entities(self, text: str) -> str:
        """Return the distinct named entities in *text* as ``"word (label)"`` lines.

        Entities are listed in order of first appearance. Returns ``""`` for
        blank input.
        """
        if not text.strip():
            return ""
        entities = self.ner_pipeline(text)
        # Aggregated results carry "entity_group"; token-level results carry
        # "entity". Accept either so the output does not depend on how the
        # pipeline was configured.
        labelled = (
            (ent["word"], ent.get("entity_group", ent.get("entity")))
            for ent in entities
        )
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make the output order vary from run to run.
        unique_entities = dict.fromkeys(labelled)
        return "\n".join(f"{word} ({label})" for word, label in unique_entities)

    def process_document(self, file_content: bytes, file_type: str, prompt: str) -> Dict[str, str]:
        """Run the full pipeline on one uploaded file.

        Returns:
            A dict with keys ``"original"`` (extracted text), ``"cleaned"``
            (BART summary), ``"processed"`` (FLAN-T5 output for *prompt*
            applied to the summary) and ``"entities"`` (NER over the full
            extracted text). When the extracted text is blank, the last three
            values are ``""``.
        """
        raw_text = self.extract_text(file_content, file_type)
        cleaned_text = self.clean_and_summarize_text(raw_text)
        processed_text = self.process_with_t5(cleaned_text, prompt)
        entities = self.extract_entities(raw_text)
        return {
            "original": raw_text,
            "cleaned": cleaned_text,
            "processed": processed_text,
            "entities": entities,
        }
def infer_file_type(file_content: bytes) -> str:
    """Guess the MIME type of an upload from its content.

    Returns:
        ``"application/pdf"`` for content starting with the ``%PDF`` magic;
        the DOCX MIME type for a ZIP archive containing ``word/document.xml``
        (the conventional main part of a Word document); ``"application/zip"``
        for any other or unreadable ZIP (e.g. .xlsx/.pptx), which
        ``extract_text`` treats as unsupported; and ``"text/plain"`` for
        everything else, including empty content.
    """
    if file_content.startswith(b"%PDF"):
        return "application/pdf"
    if file_content.startswith(b"PK\x03\x04"):
        import zipfile  # only needed to look inside ZIP containers

        # DOCX, XLSX, PPTX, JAR, ... all share the ZIP signature, so check
        # for the part that Word documents carry.
        try:
            with zipfile.ZipFile(io.BytesIO(file_content)) as archive:
                if "word/document.xml" in archive.namelist():
                    return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        except zipfile.BadZipFile:
            # Corrupt archive: fall through and report it as a generic ZIP.
            pass
        return "application/zip"
    return "text/plain"
def create_gradio_interface():
    """Build the Gradio interface for the document-processing tool.

    The interface takes an uploaded file (as raw bytes), a free-text prompt
    and an output format. It returns a text report plus a downloadable file
    holding the same results in the chosen format.

    Returns:
        The configured, not yet launched, ``gr.Interface``.
    """
    # Constructed once so every request reuses the already-loaded models.
    processor = AdvancedDocProcessor()
    # Any format other than "txt"/"docx" falls back to PDF.
    savers = {"txt": save_as_txt, "docx": save_as_docx}

    def process_and_display(file, prompt, output_format):
        """Process one upload; return (text report, path of the saved file)."""
        # With type="binary", gr.File passes the file's bytes, or None when
        # nothing was uploaded; fail with a user-visible message.
        if file is None:
            raise gr.Error("Please upload a document (PDF, DOCX, or TXT).")
        file_type = infer_file_type(file)
        results = processor.process_document(file, file_type, prompt or "")
        output_path = savers.get(output_format, save_as_pdf)(results)

        original = results["original"]
        # Mark the preview as truncated only when something was actually cut.
        preview = original[:500] + ("..." if len(original) > 500 else "")
        report = (
            f"Original Text (first 500 chars):\n{preview}\n\n"
            f"Cleaned and Summarized Text:\n{results['cleaned']}\n\n"
            f"Processed Text:\n{results['processed']}\n\n"
            f"Extracted Entities:\n{results['entities']}"
        )
        return report, output_path

    iface = gr.Interface(
        fn=process_and_display,
        inputs=[
            gr.File(label="Upload Document (PDF, DOCX, or TXT)", type="binary"),
            gr.Textbox(label="Enter your prompt for processing", lines=3),
            gr.Radio(["txt", "docx", "pdf"], label="Output Format", value="txt"),
        ],
        outputs=[
            gr.Textbox(label="Processing Results", lines=30),
            gr.File(label="Download Processed Document"),
        ],
        title="Advanced Document Processing Tool",
        description="Upload a document (PDF, DOCX, or TXT) and enter a prompt to process and analyze the text using state-of-the-art NLP models.",
    )
    return iface
def save_as_txt(results: Dict[str, str]) -> str:
    """Write *results* to a new UTF-8 ``.txt`` file and return its path.

    Each entry is written as ``"KEY:"`` (upper-cased key), its value, and a
    blank line. The file is created with ``delete=False`` so it outlives this
    function for download; it is not removed here.
    """
    body = "".join(f"{key.upper()}:\n{value}\n\n" for key, value in results.items())
    # Explicit UTF-8: extracted text may contain non-ASCII characters, which
    # would fail to encode under a non-UTF-8 locale default.
    with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False, suffix=".txt") as temp_file:
        temp_file.write(body)
    return temp_file.name
def save_as_docx(results: Dict[str, str]) -> str:
    """Write *results* to a new ``.docx`` file and return its path.

    Each entry becomes a level-1 heading (the capitalized key) followed by a
    paragraph holding its value. Requires the ``python-docx`` package.
    """
    import re

    import docx  # python-docx; the module header does not import it

    # python-docx rejects strings containing XML-illegal C0 control
    # characters, which OCR output can contain (e.g. the form feed Tesseract
    # uses as its default page separator), so drop them first.
    invalid_xml_chars = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f]")

    document = docx.Document()
    for key, value in results.items():
        document.add_heading(key.capitalize(), level=1)
        document.add_paragraph(invalid_xml_chars.sub("", value))
    with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
        # Save through the open handle: re-opening tmp.name while it is still
        # open does not work on Windows.
        document.save(tmp)
    return tmp.name
def save_as_pdf(results: Dict[str, str]) -> str:
    """Write *results* to a new letter-size ``.pdf`` file and return its path.

    Each entry becomes a ``Heading1`` paragraph (the capitalized key), a
    ``BodyText`` paragraph with its value, and a 12-point spacer.
    """
    import re
    from xml.sax.saxutils import escape

    invalid_xml_chars = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f]")

    def to_markup(text: str) -> str:
        # reportlab's Paragraph parses its input as XML-like markup, so raw
        # "<", ">" and "&" from extracted text must be escaped, and control
        # characters dropped. Explicit <br/> keeps the line breaks that
        # Paragraph would otherwise collapse into spaces.
        return "<br/>".join(
            escape(invalid_xml_chars.sub("", line)) for line in text.splitlines()
        )

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        path = tmp.name
    # Build only after the handle is closed: reportlab opens the path itself,
    # which does not work on Windows while the temporary file is still open.
    doc = SimpleDocTemplate(path, pagesize=letter)
    styles = getSampleStyleSheet()
    story = []
    for key, value in results.items():
        story.append(Paragraph(to_markup(key.capitalize()), styles["Heading1"]))
        story.append(Paragraph(to_markup(value), styles["BodyText"]))
        story.append(Spacer(1, 12))
    doc.build(story)
    return path
# Start the web UI only when this file is executed directly, not on import.
if __name__ == "__main__":
    create_gradio_interface().launch()