kh-CHEUNG's picture
Update app.py
6250640 verified
raw
history blame
5.09 kB
import numpy as np
import re
import streamlit as st
import torch
from transformers import AutoProcessor, UdopForConditionalGeneration
from PIL import Image, ImageDraw
# from datasets import load_dataset
# Compute device: CUDA when torch reports it as available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# UDOP uses 501 special loc ("location") tokens
LAYOUT_VOCAB_SIZE = 501
def extract_coordinates(string):
    """Return the last four integers found in ``string`` as ``[x1, y1, x2, y2]``.

    Every run of decimal digits in ``string`` is parsed as an integer.  When
    at least four such runs are present, the final four are returned in the
    order they appear; otherwise an empty list is returned.
    """
    values = [int(digits) for digits in re.findall(r'\d+', string)]
    # Fewer than four numbers cannot describe a box
    if len(values) < 4:
        return []
    return values[-4:]
def unnormalize_box(box, image_width, image_height):
    """Rescale a four-element box by the image size over ``LAYOUT_VOCAB_SIZE``.

    ``box[0]`` and ``box[2]`` are divided by ``LAYOUT_VOCAB_SIZE`` and
    multiplied by ``image_width``; ``box[1]`` and ``box[3]`` are divided by
    ``LAYOUT_VOCAB_SIZE`` and multiplied by ``image_height``.  The result is
    returned as ``[x1, y1, x2, y2]``.
    """
    # Axis length paired with each box entry: x, y, x, y
    extents = (image_width, image_height, image_width, image_height)
    return [
        box[index] / LAYOUT_VOCAB_SIZE * extent
        for index, extent in enumerate(extents)
    ]
processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=True)
model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")
st.title("GenAI Demo (by ITT)")
st.text("Upload and Select a document (/an image) to test the model.")
#2 column layout
col1, col2 = st.columns(2)
with col1:
# File selection
uploaded_files = st.file_uploader("Upload document(s) [/image(s)]:", type=["docx", "pdf", "pptx", "jpg", "jpeg", "png"], accept_multiple_files=True, key="fileUpload")
selected_file = st.selectbox("Select a document (/an image):", uploaded_files, format_func=lambda file: file.name if file else "None", key="fileSelect")
# Display selected file
if selected_file is not None and selected_file != "None":
file_extension = selected_file.name.split(".")[-1]
if file_extension in ["jpg", "jpeg", "png"]:
image = Image.open(selected_file).convert("RGB")
st.image(selected_file, caption="Selected Image")
else:
st.write("Selected file: ", selected_file.name)
# Model Testing
with col2:
## Question (/Prompt)
# question = "Question answering. How many unsafe practice of Lifting Operation?"
default_question = "Is this a Lifting Operation scene?"
task_type = st.selectbox("Question Type:", ("Classification", "Question Answering", "Layout Analysis"), index=1, key="taskSelect")
question_text = st.text_area("Prompt:", placeholder=default_question, key="questionInput")
if question_text is not None:
question = task_type + ". " + question_text
else:
question = task_type + ". " + default_question
## Test button
testButton = st.button("Test Model", key="testStart")
## Perform Model Testing when Image is uploaded and selected as well as Test button is pressed
if testButton and selected_file != "None":
st.write("Testing the model with the selected image...")
# encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
model_encoding = processor(images=image, text=question, return_tensors="pt")
model_output = model.generate(**model_encoding)
match task_type:
case "Classification":
output_text = processor.batch_decode(model_output, skip_special_tokens=True)[0]
st.write(output_text)
case "Question Answering":
output_text = processor.batch_decode(model_output, skip_special_tokens=True)[0]
st.write(output_text)
case "Layout Analysis":
output_text = processor.batch_decode(model_output, skip_special_tokens=False)[0]
mean = processor.image_processor.image_mean
std = processor.image_processor.image_std
unnormalized_image = (model_encoding.pixel_values.squeeze().numpy() * np.array(std)[:, None, None]) + np.array(mean)[:, None, None]
unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
unnormalized_image = Image.fromarray(unnormalized_image)
# Get the coordinates from the output text and denormalize them
coordinates = extract_coordinates(output_text)
if coordinates:
coordinates = unnormalize_box(coordinates, unnormalized_image.width, unnormalized_image.height)
draw = ImageDraw.Draw(unnormalized_image)
draw.rectangle(coordinates, outline="red")
st.image(unnormalized_image, caption="Output Image")
else:
st.write("Cannot obtain Bounding Box coordinates: " + output_text)
elif testButton and selected_file == "None":
st.write("Please upload and select a document (/an image).")