import io
import logging
import tempfile
from typing import Dict
from xml.sax.saxutils import escape

import docx  # python-docx, used to write the DOCX output
import docx2txt
import gradio as gr
import pytesseract
from pdf2image import convert_from_bytes
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
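
# NOTE: two of the extractors shell out to system binaries that must be
# installed separately: pdf2image needs the poppler utilities, and
# pytesseract needs the Tesseract OCR engine. The `docx` import comes from
# the python-docx package (a separate dependency from docx2txt).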

class AdvancedDocProcessor:
    def __init__(self):
        # BART handles cleaning and summarization.
        self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

        # Flan-T5 handles prompt-driven processing.
        self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")

        # BERT fine-tuned on CoNLL-2003 handles named-entity recognition.
        # aggregation_strategy="simple" merges WordPiece pieces back into
        # whole words, so the output lists entities rather than subword tokens.
        self.ner_pipeline = pipeline(
            "ner",
            model="dbmdz/bert-large-cased-finetuned-conll03-english",
            aggregation_strategy="simple",
        )
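    # NOTE: all three checkpoints are pulled from the Hugging Face Hub on
    # first run (several GB in total) and run on CPU by default. For GPU
    # inference you could move the models with .to("cuda") and pass
    # device=0 to pipeline(); that is an optional tweak, not something
    # this script does.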

    def extract_text(self, file_content: bytes, file_type: str) -> str:
        """Extract text from various file types."""
        try:
            if file_type == "application/pdf":
                return self.extract_text_from_pdf(file_content)
            elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return self.extract_text_from_docx(file_content)
            elif file_type == "text/plain":
                return file_content.decode('utf-8')
            else:
                raise ValueError(f"Unsupported file type: {file_type}")
        except Exception as e:
            logger.error(f"Error extracting text: {str(e)}")
            return ""

    def extract_text_from_pdf(self, pdf_content: bytes) -> str:
        """Extract text from PDF using OCR."""
        images = convert_from_bytes(pdf_content)
        text = ""
        for image in images:
            text += pytesseract.image_to_string(image)
        return text
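    # NOTE: OCR is the safe default here because it also works for scanned
    # PDFs, but rasterizing and OCR-ing every page is slow on long documents.
    # For born-digital PDFs, a direct text extractor (e.g., pypdf or
    # pdfminer.six) would be much faster; swapping one in is left as an option.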

    def extract_text_from_docx(self, docx_content: bytes) -> str:
        """Extract text from a DOCX file."""
        return docx2txt.process(io.BytesIO(docx_content))

    def clean_and_summarize_text(self, text: str) -> str:
        """Clean and summarize the text using BART."""
        # BART accepts at most 1024 tokens, so with truncation=True only the
        # head of a long document feeds the summary.
        inputs = self.bart_tokenizer([text], max_length=1024, return_tensors="pt", truncation=True)
        summary_ids = self.bart_model.generate(inputs["input_ids"], num_beams=4, max_length=150, early_stopping=True)
        return self.bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_with_t5(self, text: str, prompt: str) -> str:
        """Process the text with T5 based on the given prompt."""
        input_text = f"{prompt} {text}"
        input_ids = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).input_ids
        # temperature only takes effect when sampling is enabled, so
        # do_sample=True is required for the 0.7 setting to do anything.
        outputs = self.t5_model.generate(input_ids, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
        return self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
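    # Flan-T5 is instruction-tuned, so plain-language prompts such as
    # "Summarize the following text:" or "List the key obligations in:"
    # work as-is; these examples are illustrative, not an exhaustive spec.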

    def extract_entities(self, text: str) -> str:
        """Extract named entities from the text."""
        # The underlying BERT model caps sequences at 512 tokens, so as a
        # rough guard only the first 1,500 characters are tagged here; the
        # cutoff is an arbitrary choice, and chunking the full text would
        # be more thorough.
        entities = self.ner_pipeline(text[:1500])
        unique_entities = set((ent['word'], ent['entity_group']) for ent in entities)
        return "\n".join(f"{word} ({entity})" for word, entity in unique_entities)

    def process_document(self, file_content: bytes, file_type: str, prompt: str) -> Dict[str, str]:
        """Run the full pipeline: extract, summarize, prompt-process, and tag entities."""
        raw_text = self.extract_text(file_content, file_type)
        cleaned_text = self.clean_and_summarize_text(raw_text)
        processed_text = self.process_with_t5(cleaned_text, prompt)
        entities = self.extract_entities(raw_text)

        return {
            "original": raw_text,
            "cleaned": cleaned_text,
            "processed": processed_text,
            "entities": entities
        }

def infer_file_type(file_content: bytes) -> str:
    """Infer the file type from the byte content."""
    if file_content.startswith(b'%PDF'):  # PDF magic number
        return "application/pdf"
    elif file_content.startswith(b'PK\x03\x04'):
        # ZIP local-file header; DOCX is a ZIP container. Other OOXML formats
        # (XLSX, PPTX) share this signature, so this check is a heuristic.
        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    else:
        return "text/plain"

def create_gradio_interface():
    processor = AdvancedDocProcessor()

    def process_and_display(file, prompt, output_format):
        # With type="binary", gr.File hands the callback the raw bytes.
        file_content = file
        file_type = infer_file_type(file_content)
        results = processor.process_document(file_content, file_type, prompt)

        if output_format == "txt":
            output_path = save_as_txt(results)
        elif output_format == "docx":
            output_path = save_as_docx(results)
        else:
            output_path = save_as_pdf(results)

        return (f"Original Text (first 500 chars):\n{results['original'][:500]}...\n\n"
                f"Cleaned and Summarized Text:\n{results['cleaned']}\n\n"
                f"Processed Text:\n{results['processed']}\n\n"
                f"Extracted Entities:\n{results['entities']}"), output_path

    iface = gr.Interface(
        fn=process_and_display,
        inputs=[
            gr.File(label="Upload Document (PDF, DOCX, or TXT)", type="binary"),
            gr.Textbox(label="Enter your prompt for processing", lines=3),
            gr.Radio(["txt", "docx", "pdf"], label="Output Format", value="txt")
        ],
        outputs=[
            gr.Textbox(label="Processing Results", lines=30),
            gr.File(label="Download Processed Document")
        ],
        title="Advanced Document Processing Tool",
        description="Upload a document (PDF, DOCX, or TXT) and enter a prompt to process and analyze the text using state-of-the-art NLP models.",
    )

    return iface

def save_as_txt(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as temp_file:
        for key, value in results.items():
            temp_file.write(f"{key.upper()}:\n{value}\n\n")
        return temp_file.name
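
# NOTE: all three writers use NamedTemporaryFile(delete=False) so Gradio can
# serve the file after the function returns; nothing deletes these files
# afterwards, so they accumulate until the OS cleans its temp directory.
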
def save_as_docx(results: Dict[str, str]) -> str:
    doc = docx.Document()
    for key, value in results.items():
        doc.add_heading(key.capitalize(), level=1)
        doc.add_paragraph(value)

    with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp:
        doc.save(tmp.name)
        return tmp.name

def save_as_pdf(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
        doc = SimpleDocTemplate(tmp.name, pagesize=letter)
        styles = getSampleStyleSheet()
        story = []

        for key, value in results.items():
            story.append(Paragraph(key.capitalize(), styles['Heading1']))
            # Paragraph parses its text as XML-like markup, so raw '<', '>'
            # and '&' in the document text must be escaped or build() fails.
            story.append(Paragraph(escape(value), styles['BodyText']))
            story.append(Spacer(1, 12))

        doc.build(story)
        return tmp.name

if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()
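    # Running this script starts a local Gradio server (http://127.0.0.1:7860
    # by default); pass share=True to launch() for a temporary public URL.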