import io
import logging
import tempfile
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from typing import Dict
from xml.sax.saxutils import escape

import docx
import docx2txt
import gradio as gr
import numpy as np
import pytesseract
import torch
from pdf2image import convert_from_bytes
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AdvancedDocProcessor:
    """Extract, summarize, and analyze text from PDF, DOCX, and TXT documents."""

    def __init__(self):
        # BART for cleaning and summarizing the raw extracted text.
        self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn", torch_dtype=torch.float32)

        # FLAN-T5 for prompt-driven processing of the summarized text.
        self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.float32)

        # BERT fine-tuned on CoNLL-2003 for named entity recognition.
        self.ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", torch_dtype=torch.float32)
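        # Note: without an aggregation strategy the NER pipeline emits one entry
        # per sub-word token (e.g. "Hu", "##gging"). Passing
        # aggregation_strategy="simple" would merge these into whole words, but
        # it also renames the output key from 'entity' (which extract_entities
        # below relies on) to 'entity_group'.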

    def extract_text(self, file_content: bytes, file_type: str) -> str:
        """Extract text from various file types."""
        try:
            if file_type == "application/pdf":
                return self.extract_text_from_pdf(file_content)
            elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
                return self.extract_text_from_docx(file_content)
            elif file_type == "text/plain":
                return file_content.decode('utf-8')
            else:
                raise ValueError(f"Unsupported file type: {file_type}")
        except Exception as e:
            logger.error(f"Error extracting text: {str(e)}")
            return ""

    def extract_text_from_pdf(self, pdf_content: bytes) -> str:
        """Extract text from a PDF by rendering each page and running OCR."""
        try:
            # pdf2image needs the poppler utilities installed on the host, and
            # pytesseract needs the tesseract binary.
            images = convert_from_bytes(pdf_content, timeout=60)
            text = ""
            for image in images:
                text += pytesseract.image_to_string(image)
            return text
        except Exception as e:
            logger.error(f"Error extracting text from PDF: {str(e)}")
            return ""

    def extract_text_from_docx(self, docx_content: bytes) -> str:
        """Extract text from a DOCX file."""
        try:
            return docx2txt.process(io.BytesIO(docx_content))
        except Exception as e:
            logger.error(f"Error extracting text from DOCX: {str(e)}")
            return ""

    def clean_and_summarize_text(self, text: str) -> str:
        """Clean and summarize the text using BART."""
        try:
            # Chunk by characters; 1024 characters fits comfortably within
            # BART's 1024-token window.
            chunk_size = 1024
            chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            summarized_chunks = []
            for chunk in chunks:
                inputs = self.bart_tokenizer([chunk], max_length=1024, return_tensors="pt", truncation=True)
                # Pass the attention mask along with the input IDs.
                summary_ids = self.bart_model.generate(**inputs, num_beams=4, max_length=150, early_stopping=True)
                summarized_chunks.append(self.bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True))
            return " ".join(summarized_chunks)
        except Exception as e:
            logger.error(f"Error cleaning and summarizing text: {str(e)}")
            return text
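
    # A token-aware alternative to the character-based chunking above (an
    # illustrative sketch, not used by the pipeline): splitting on token IDs
    # guarantees every chunk fits the model window exactly.
    def _chunk_by_tokens(self, text: str, max_tokens: int = 1024):
        token_ids = self.bart_tokenizer.encode(text, add_special_tokens=False)
        for i in range(0, len(token_ids), max_tokens):
            yield self.bart_tokenizer.decode(token_ids[i:i + max_tokens])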

    def process_with_t5(self, text: str, prompt: str) -> str:
        """Process the text with T5 based on the given prompt."""
        try:
            chunk_size = 512
            chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            processed_chunks = []
            for chunk in chunks:
                input_text = f"{prompt} {chunk}"
                inputs = self.t5_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
                outputs = self.t5_model.generate(
                    **inputs,
                    max_length=150,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7
                )
                processed_chunks.append(self.t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
            return " ".join(processed_chunks)
        except Exception as e:
            logger.error(f"Error processing with T5: {str(e)}")
            return f"Error processing text: {str(e)}"

    def extract_entities(self, text: str) -> str:
        """Extract named entities from the text."""
        try:
            # Keep chunks well under BERT's 512-token window; longer inputs can
            # overflow the model's position embeddings and crash the pipeline.
            chunk_size = 1500
            chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            all_entities = []
            for chunk in chunks:
                entities = self.ner_pipeline(chunk)
                all_entities.extend(entities)
            unique_entities = set((ent['word'], ent['entity']) for ent in all_entities)
            return "\n".join(f"{word} ({entity})" for word, entity in unique_entities)
        except Exception as e:
            logger.error(f"Error extracting entities: {str(e)}")
            return "Error extracting entities"

    def process_document(self, file_content: bytes, file_type: str, prompt: str) -> Dict[str, str]:
        """Run the full pipeline: extract, summarize, prompt-process, and tag entities."""
        raw_text = self.extract_text(file_content, file_type)
        cleaned_text = self.clean_and_summarize_text(raw_text)
        processed_text = self.process_with_t5(cleaned_text, prompt)
        entities = self.extract_entities(raw_text)

        return {
            "cleaned": cleaned_text,
            "processed": processed_text,
            "entities": entities
        }
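
# A minimal usage sketch (the file name and prompt are illustrative):
#
#   processor = AdvancedDocProcessor()
#   with open("report.pdf", "rb") as f:
#       results = processor.process_document(f.read(), "application/pdf", "Summarize:")
#   print(results["processed"])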


def create_gradio_interface():
    processor = AdvancedDocProcessor()

    def process_and_display(file, prompt, output_format):
        def processing_task():
            # gr.File may deliver a filepath string, raw bytes, or (in older
            # Gradio versions) a tempfile wrapper exposing a .name path.
            if isinstance(file, str):
                with open(file, 'rb') as f:
                    file_content = f.read()
            elif isinstance(file, bytes):
                file_content = file
            else:
                with open(file.name, 'rb') as f:
                    file_content = f.read()

            file_type = infer_file_type(file_content)
            results = processor.process_document(file_content, file_type, prompt)

            if output_format == "txt":
                output_path = save_as_txt(results)
            elif output_format == "docx":
                output_path = save_as_docx(results)
            else:
                output_path = save_as_pdf(results)

            return (f"Cleaned and Summarized Text:\n{results['cleaned']}\n\n"
                    f"Processed Text:\n{results['processed']}\n\n"
                    f"Extracted Entities:\n{results['entities']}"), output_path

        # Run the pipeline in a worker thread and enforce a five-minute timeout.
        with ThreadPoolExecutor() as executor:
            future = executor.submit(processing_task)
            try:
                return future.result(timeout=300)
            except TimeoutError:
                return "Processing timed out after 5 minutes.", None
            except Exception as e:
                logger.error(f"Error during processing: {str(e)}")
                return f"An error occurred during processing: {str(e)}", None

    iface = gr.Interface(
        fn=process_and_display,
        inputs=[
            gr.File(label="Upload Document (PDF, DOCX, or TXT)"),
            gr.Textbox(label="Enter your prompt for processing", lines=3),
            gr.Radio(["txt", "docx", "pdf"], label="Output Format", value="txt")
        ],
        outputs=[
            gr.Textbox(label="Processing Results", lines=30),
            gr.File(label="Download Processed Document")
        ],
        title="Advanced Document Processing Tool",
        description="Upload a document (PDF, DOCX, or TXT) and enter a prompt to process and analyze the text using state-of-the-art NLP models.",
    )

    return iface


def infer_file_type(file_content: bytes) -> str:
    """Infer the file type from the leading bytes (a lightweight heuristic)."""
    if file_content.startswith(b'%PDF'):
        return "application/pdf"
    elif file_content.startswith(b'PK\x03\x04'):
        # The ZIP signature matches any OOXML archive (DOCX, XLSX, PPTX, ...);
        # DOCX is assumed since it is the only ZIP-based format accepted here.
        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    else:
        return "text/plain"


def save_as_txt(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix='.txt') as temp_file:
        for key, value in results.items():
            temp_file.write(f"{key.upper()}:\n{value}\n\n")
        return temp_file.name


def save_as_docx(results: Dict[str, str]) -> str:
    doc = docx.Document()
    for key, value in results.items():
        doc.add_heading(key.capitalize(), level=1)
        doc.add_paragraph(value)

    with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp:
        doc.save(tmp.name)
        return tmp.name


def save_as_pdf(results: Dict[str, str]) -> str:
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
        doc = SimpleDocTemplate(tmp.name, pagesize=letter)
        styles = getSampleStyleSheet()
        story = []

        for key, value in results.items():
            story.append(Paragraph(key.capitalize(), styles['Heading1']))
            # reportlab's Paragraph parses its text as XML-like markup, so
            # escape &, <, and > to avoid parse errors on raw document text.
            story.append(Paragraph(escape(value), styles['BodyText']))
            story.append(Spacer(1, 12))

        doc.build(story)
        return tmp.name


if __name__ == "__main__":
    # Use numpy's pre-1.14 legacy printing format.
    np.set_printoptions(legacy='1.13')

    iface = create_gradio_interface()
    iface.launch()
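    # By default launch() serves only on localhost; Gradio's share=True flag
    # (iface.launch(share=True)) would create a temporary public link instead.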