Update app.py
app.py CHANGED
@@ -4,6 +4,7 @@ import logging
 import tempfile
 import io
 import torch
+import numpy as np
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from pdf2image import convert_from_bytes
 from PIL import Image
@@ -24,14 +25,14 @@ class AdvancedDocProcessor:
     def __init__(self):
         # Initialize BART model for text cleaning and summarization
         self.bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
-        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
+        self.bart_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn", torch_dtype=torch.float32)

         # Initialize T5 model for text generation tasks
         self.t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
-        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
+        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.float32)

         # Initialize pipeline for named entity recognition
-        self.ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")
+        self.ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", torch_dtype=torch.float32)

     def extract_text(self, file_content: bytes, file_type: str) -> str:
         """Extract text from various file types."""
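A quick way to sanity-check this hunk is to load one of the checkpoints named in the diff and inspect its parameters. This is only a sketch; float32 should already be the default for these checkpoints, so the explicit torch_dtype mainly documents intent:

import torch
from transformers import AutoModelForSeq2SeqLM

# Load one of the models from the diff and confirm the parameter dtype.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.float32)
print(next(model.parameters()).dtype)  # expected: torch.float32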
@@ -137,7 +138,12 @@ def create_gradio_interface():

     def process_and_display(file, prompt, output_format):
         def processing_task():
-
+            if isinstance(file, str):  # If it's a file path
+                with open(file, 'rb') as f:
+                    file_content = f.read()
+            else:  # If it's already file content
+                file_content = file
+
             file_type = infer_file_type(file_content)
             results = processor.process_document(file_content, file_type, prompt)

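The new isinstance branch handles the case where Gradio passes the callback a file path. Depending on how the gr.File input is configured (type="filepath" vs. type="binary", or the older tempfile-style wrapper that exposes a .name attribute), the value can also arrive as raw bytes or as a file-like object, so a small normalizer along these lines covers all three cases; read_file_content is a hypothetical helper, not something defined in app.py:

def read_file_content(file) -> bytes:
    # Hypothetical helper: normalize whatever the Gradio File component passes in to raw bytes.
    if isinstance(file, str):                 # type="filepath": a path string
        with open(file, "rb") as f:
            return f.read()
    if isinstance(file, (bytes, bytearray)):  # type="binary": raw bytes
        return bytes(file)
    if hasattr(file, "name"):                 # tempfile-style wrapper: read from its path
        with open(file.name, "rb") as f:
            return f.read()
    raise TypeError(f"Unsupported file input: {type(file)!r}")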
@@ -158,6 +164,9 @@ def create_gradio_interface():
             return future.result(timeout=300)  # 5 minutes timeout
         except TimeoutError:
             return "Processing timed out after 5 minutes.", None
+        except Exception as e:
+            logger.error(f"Error during processing: {str(e)}")
+            return f"An error occurred during processing: {str(e)}", None

     iface = gr.Interface(
         fn=process_and_display,
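One detail worth flagging on the timeout path: future.result(timeout=...) raises concurrent.futures.TimeoutError, which on Python versions before 3.11 is not the built-in TimeoutError, so the bare except TimeoutError: only matches if that name is imported from concurrent.futures (on 3.11+ the two are aliases). A self-contained sketch of the pattern, with run_with_timeout as a hypothetical helper name:

from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout

def run_with_timeout(fn, timeout_s=300):
    # Run fn in a worker thread and stop waiting for it after timeout_s seconds.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(fn)
        return future.result(timeout=timeout_s)
    except FuturesTimeout:
        # The worker thread is not killed; we only stop waiting for its result.
        return "Processing timed out.", None
    except Exception as e:
        return f"An error occurred during processing: {e}", None
    finally:
        # wait=False so a timed-out task does not block shutdown here;
        # the thread still runs to completion in the background.
        executor.shutdown(wait=False)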
@@ -217,5 +226,8 @@ def save_as_pdf(results: Dict[str, str]) -> str:

 # Launch the Gradio app
 if __name__ == "__main__":
+    # Set NumPy print options to avoid warnings
+    np.set_printoptions(legacy='1.13')
+
     iface = create_gradio_interface()
     iface.launch()
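For context on the last hunk: np.set_printoptions(legacy='1.13') is a printing switch that restores the pre-NumPy-1.14 repr/str style for arrays and scalars; it affects only how values are displayed, not numerical behavior. A minimal illustration:

import numpy as np

# Restores the older array/scalar formatting; computation is unaffected.
np.set_printoptions(legacy='1.13')
print(np.array([0.1, 1.5]))  # printed with the pre-1.14 spacing/precision style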