autofinance-us / app.py
fsommers's picture
Test updates
b25b478 unverified
raw
history blame
4.71 kB
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import torch
import torch.nn.functional as F
import pytesseract
import plotly.express as px
from torch.utils.data import Dataset, DataLoader, Subset
import os
import io
import pytesseract
import fitz
from typing import List
import json
import sys
from pathlib import Path
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face checkpoint that supplies the LayoutLMv3 tokenizer.
TOKENIZER = "microsoft/layoutlmv3-base"
# Fine-tuned LayoutLMv3 sequence-classification checkpoint used for prediction.
MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"
# Extra CLI flags for Tesseract: automatic page segmentation.
TESS_OPTIONS = "--psm 3"
@st.cache_resource
def create_ocr_reader():
    """Build the OCR pre-processing callable (cached with ``st.cache_resource``).

    Returns:
        ``prepare_image(image) -> (words, boxes)``. It runs Tesseract on
        ``image`` and returns the recognised words together with one
        ``[x0, y0, x1, y1]`` box per word. The boxes are rescaled from pixel
        coordinates to a 0-1000 grid (the range LayoutLMv3 expects).
    """

    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0) -> List[int]:
        """Scale an ``[x0, y0, x1, y1]`` box; each coordinate is truncated to int."""
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale),
        ]

    def ocr_page(image) -> dict:
        """OCR ``image`` with Tesseract.

        Returns:
            A dict with two parallel lists: ``"words"`` (each word as ``str``)
            and ``"bbox"`` (one ``[x0, y0, x1, y1]`` pixel box per word).
        """
        ocr_df = pytesseract.image_to_data(image, output_type="data.frame", config=TESS_OPTIONS)
        # Drop every row with a missing value (presumably Tesseract's
        # structural rows, which carry no text).
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        # Round any float columns to whole numbers and store them as int.
        float_cols = ocr_df.select_dtypes("float").columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Treat empty / whitespace-only cells as missing and drop those rows too.
        ocr_df = ocr_df.replace(r"^\s*$", np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)

        words = [str(w) for w in ocr_df.text]
        # Tesseract reports (left, top, width, height); convert to corner form.
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in ocr_df[["left", "top", "width", "height"]].itertuples(index=False)
        ]
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        """OCR ``image`` and return ``(words, boxes)`` with boxes on a 0-1000 grid.

        Raises:
            ValueError: if any scaled coordinate falls outside ``[0, 1000]``.
        """
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height

        words = list(ocr_data["words"])
        boxes = [scale_bounding_box(b, width_scale, height_scale) for b in ocr_data["bbox"]]

        # LayoutLMv3 needs box coordinates in 0..1000. The original code used
        # a bare ``raise`` here, which has no active exception and only
        # produced "RuntimeError: No active exception to reraise".
        for box in boxes:
            if any(not 0 <= coord <= 1000 for coord in box):
                raise ValueError(
                    f"Bounding box {box} is outside the 0-1000 range expected by LayoutLMv3"
                )
        return words, boxes

    return prepare_image
@st.cache_resource
def create_model():
    """Load the fine-tuned LayoutLMv3 classifier named by ``MODEL_NAME``.

    The model is put in evaluation mode, moved to ``DEVICE`` and cached with
    ``st.cache_resource``.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()
    classifier = classifier.to(DEVICE)
    return classifier
@st.cache_resource
def create_processor():
    """Build the LayoutLMv3 processor (cached with ``st.cache_resource``).

    The image feature extractor is created with ``apply_ocr=False`` because
    words and boxes are supplied separately by the OCR reader. The tokenizer
    is loaded from the ``TOKENIZER`` checkpoint.
    """
    return LayoutLMv3Processor(
        feature_extractor=LayoutLMv3FeatureExtractor(apply_ocr=False),
        tokenizer=LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER),
    )
def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """Classify a single document image.

    Args:
        image: PIL image of the page. Images not in ``"RGB"`` mode (e.g. RGBA
            or palette PNGs, grayscale scans) are converted to RGB first.
        reader: callable that returns ``(words, boxes)`` for ``image``, e.g. the
            function built by ``create_ocr_reader``.
        processor: LayoutLMv3 processor that turns image + words + boxes into
            model inputs.
        model: LayoutLMv3 sequence-classification model, already on ``DEVICE``.

    Returns:
        ``(predicted_class_id, probabilities)``. ``probabilities`` is a flat
        list of softmax scores, one per label of the model.
    """
    # The LayoutLMv3 image processor normalises with 3-channel mean/std. RGBA,
    # palette ("P") or grayscale ("L") uploads would otherwise be rejected or
    # mis-normalised, so coerce the mode once, up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    words, boxes = reader(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )
    logits = output.logits
    # One image -> batch of one: argmax over the label axis, take the only row.
    predicted_class = logits.argmax(dim=-1)[0].item()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class, probabilities
# NOTE(review): the real app flow (upload -> OCR -> classify -> probability
# chart) is commented out below, so the page currently only renders a
# placeholder. It was presumably disabled for testing; uncomment it to restore
# the classifier UI.
st.markdown("Test")
# reader = create_ocr_reader()
# processor = create_processor()
# model = create_model()
# uploaded_file = st.file_uploader("Choose a JPG file", ["jpg", "png"])
# if uploaded_file is not None:
# bytes_data = io.BytesIO(uploaded_file.read())
# image = Image.open(bytes_data)
# st.image(image, caption="Uploaded Image", use_column_width=True)
# predicted, probabilities = predict(image, reader, processor, model)
# predicted_label = model.config.id2label[predicted]
# st.markdown(f"Predicted Label: {predicted_label}")
# df = pd.DataFrame({
# "Label": list(model.config.id2label.values()),
# "Probability": probabilities
# })
# fig = px.bar(df, x="Label", y="Probability")
# st.plotly_chart(fig, use_container_width=True)