from ultralytics import YOLO
import supervision as sv
import cv2
import gradio as gr
import os
import glob
import time
import subprocess

import numpy as np
import pandas as pd
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from pdf2image import convert_from_path
import pymupdf  # "fitz" is the legacy import name for this same library
import camelot

# Install flash-attn at runtime (a common Hugging Face Spaces workaround);
# FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling the CUDA kernels during install.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
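
# Florence-2 performs OCR on the cropped detections; two ONNX-exported YOLO
# models detect "mark" objects and tables, respectively.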
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
onnx_model = YOLO("models/best.onnx", task='detect')
onnx_model_table = YOLO("models/tables/best.onnx", task='detect')
def filter_detections(detections, target_class_name="mark"):
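    """Keep only detections whose class name equals target_class_name (modifies detections in place)."""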
indices_to_keep = [i for i, class_name in enumerate(detections.data['class_name']) if
class_name == target_class_name]
filtered_xyxy = detections.xyxy[indices_to_keep]
filtered_confidence = detections.confidence[indices_to_keep]
filtered_class_id = detections.class_id[indices_to_keep]
filtered_class_name = detections.data['class_name'][indices_to_keep]
detections.xyxy = filtered_xyxy
detections.confidence = filtered_confidence
detections.class_id = filtered_class_id
detections.data['class_name'] = filtered_class_name
return detections
def add_label_detection(detections):
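    """Number each detection's class name (e.g. "table 1", "table 2") and offset class ids so each detection gets a unique id."""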
updated_class = [f"{class_name} {i + 1}" for i, class_name in enumerate(detections.data['class_name'])]
updated_id = [class_id + i for i, class_id in enumerate(detections.class_id)]
detections.data['class_name'] = np.array(updated_class)
detections.class_id = np.array(updated_id)
return detections
def ends_with_number(s):
    """Return True if the string is non-empty and ends with a digit."""
    return bool(s) and s[-1].isdigit()
def ocr(image, prompt="<OCR>"):
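    """Run Florence-2 on a numpy image and return the parsed answer dict keyed by the prompt."""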
    original_height, original_width = image.shape[:2]
    # cv2 loads images as BGR, while the Hugging Face processor expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
early_stopping=False,
do_sample=False,
num_beams=3
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text,
task=prompt,
image_size=(original_width, original_height)
)
return parsed_answer
def parse_detection(detections):
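    """Flatten an sv.Detections object into a list of per-detection dicts."""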
parsed_rows = []
for i in range(len(detections.xyxy)):
x_min = float(detections.xyxy[i][0])
y_min = float(detections.xyxy[i][1])
x_max = float(detections.xyxy[i][2])
y_max = float(detections.xyxy[i][3])
width = int(x_max - x_min)
height = int(y_max - y_min)
row = {
"top": int(y_min),
"left": int(x_min),
"width": width,
"height": height,
"class_id": ""
if detections.class_id is None
else int(detections.class_id[i]),
"confidence": ""
if detections.confidence is None
else float(detections.confidence[i]),
"tracker_id": ""
if detections.tracker_id is None
else int(detections.tracker_id[i]),
}
if hasattr(detections, "data"):
for key, value in detections.data.items():
row[key] = (
str(value[i])
                    if hasattr(value, "__getitem__") and getattr(value, "ndim", 1) != 0
else str(value)
)
parsed_rows.append(row)
return parsed_rows
def cut_and_save_image(image, parsed_detections, output_dir):
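    """Crop every 'mark' detection from the image, save an upscaled PNG per crop, and return the file paths."""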
output_path_list = []
for i, det in enumerate(parsed_detections):
# Check if the class is 'mark'
if det['class_name'] == 'mark':
top = det['top']
left = det['left']
width = det['width']
height = det['height']
# Cut the image
cut_image = image[top:top + height, left:left + width]
# Save the image
            output_path = f"{output_dir}/cut_image_{i}.png"
            # Upscale the crop so the OCR model has more pixels to work with
            scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
            cv2.imwrite(output_path, scaled_image)
output_path_list.append(output_path)
return output_path_list
def analysis(progress=gr.Progress()):
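    """OCR every cropped image in output/ and tally the total count per mark label."""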
progress(0, desc="Analyzing...")
list_files = glob.glob("output/*.png")
prompt = "<OCR>"
results = {}
for filepath in progress.tqdm(list_files):
basename = os.path.basename(filepath)
image = cv2.imread(filepath)
start_time = time.time()
parsed_answer = ocr(image, prompt)
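        # Mark labels are expected to end in a digit; default a missing suffix to "1".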
if not ends_with_number(parsed_answer[prompt]):
parsed_answer[prompt] += "1"
results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
print(basename, parsed_answer[prompt])
print("Time taken:", time.time() - start_time)
return pd.DataFrame(results.items(), columns=['Mark', 'Total']).reset_index(drop=False).rename(columns={'index': 'No.'})
def inference(
image_path,
conf_threshold,
iou_threshold,
):
"""
YOLOv8 inference function
Args:
image_path: Path to the image
conf_threshold: Confidence threshold
iou_threshold: IoU threshold
Returns:
Rendered image
"""
image = cv2.imread(image_path)
results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
detections = sv.Detections.from_ultralytics(results)
detections = filter_detections(detections)
parsed_detections = parse_detection(detections)
output_dir = "output"
# Check if the output directory exists, clear all the files inside
if not os.path.exists(output_dir):
os.makedirs(output_dir)
else:
for f in os.listdir(output_dir):
os.remove(os.path.join(output_dir, f))
output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
annotated_image = image.copy()
annotated_image = box_annotator.annotate(
scene=annotated_image,
detections=detections
)
annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    # Convert BGR (cv2) to RGB so Gradio displays the colors correctly.
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    return annotated_image, output_path_list
def read_table(sheet):
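    """Return the requested sheet from the exported workbook, or an empty DataFrame if it does not exist."""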
excel_path = "output_tables.xlsx"
if os.path.exists(excel_path):
sheetnames = pd.ExcelFile(excel_path).sheet_names
if sheet in sheetnames:
df = pd.read_excel(excel_path, sheet_name=sheet)
else:
df = pd.DataFrame()
else:
df = pd.DataFrame()
return df
def validate_df(df):
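    """Replace integer column labels (produced when no header row is found) with readable "Col N" names."""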
columns = []
count = 1
    for col in df.columns:
        # Column labels may be Python or numpy integers when the parser found no header row.
        if isinstance(col, (int, np.integer)):
            columns.append(f"Col {count}")
            count += 1
        else:
            columns.append(col)
df.columns = columns
return df
def analyze_table(file, conf_threshold, iou_threshold, progress=gr.Progress()):
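    """Render the first PDF page, detect tables with YOLO, extract each one
    (PyMuPDF first, with camelot as fallback), and export them all to an Excel workbook."""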
progress(0, desc="Parsing table...")
    img = convert_from_path(file)[0]
    doc = pymupdf.open(file)
    # Render the first page at 1:1 zoom so pixel coordinates equal PDF points,
    # which the camelot fallback below relies on when flipping the y-axis.
    mat = pymupdf.Matrix(1.0, 1.0)
    pix = doc[0].get_pixmap(matrix=mat)
    pix.save("temp.png")
    image = cv2.imread("temp.png")
    file_height, file_width, _ = image.shape
results = onnx_model_table(image, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
detections = sv.Detections.from_ultralytics(results)
detections = add_label_detection(detections)
parsed_detections = parse_detection(detections)
# print(parsed_detections)
output_dir = "output_table"
# Check if the output directory exists, clear all the files inside
if not os.path.exists(output_dir):
os.makedirs(output_dir)
else:
for f in os.listdir(output_dir):
os.remove(os.path.join(output_dir, f))
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
annotated_image = image.copy()
annotated_image = box_annotator.annotate(
scene=annotated_image,
detections=detections
)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    # Convert BGR (cv2) to RGB so Gradio displays the colors correctly.
    annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    pdf_page = doc[0]
    table_area = [(ind,
                   pymupdf.Rect(det['left'], det['top'], det['left'] + det['width'], det['top'] + det['height']))
                  for ind, det in enumerate(parsed_detections)
                  ]
table_list = []
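    # Try PyMuPDF's built-in table finder first; fall back to camelot's stream parser.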
for ind, area in progress.tqdm(table_area):
pdf_tabs = pdf_page.find_tables(clip=area)
if len(pdf_tabs.tables) > 0:
pdf_df = pdf_tabs[0].to_pandas()
print("Fitz Table Found!")
        else:
            cur = parsed_detections[ind]
            # Camelot expects PDF coordinates (origin at bottom-left, pages 1-indexed),
            # so flip the y-axis of the detected box.
            table_areas = [f"{cur['left']},{file_height - cur['top']},{cur['left'] + cur['width']},{file_height - (cur['top'] + cur['height'])}"]
            tables = camelot.read_pdf(file, pages='1', flavor='stream', row_tol=10, table_areas=table_areas)
            pdf_df = tables[0].df
            print("Camelot Table Found!")
pdf_df = validate_df(pdf_df)
table_list.append(pdf_df)
excel_path = "output_tables.xlsx"
sheet_list = []
with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
for i in range(len(table_list)):
sheet_name = f"Table_{i + 1}"
table_list[i].to_excel(writer, sheet_name=sheet_name, index=False)
sheet_list.append(sheet_name)
return img, annotated_image, excel_path, ", ".join(sheet_list)
TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
identification application. This tool allows you to upload an image, and it will identify and annotate objects within
the image. Additionally, you can perform OCR analysis on the detected objects.</p>
"""
CSS = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
h1 {
text-align: center;
}
"""
EXAMPLES = [
['examples/train1.png', 0.6, 0.25],
['examples/train2.png', 0.9, 0.25],
['examples/train3.png', 0.6, 0.25]
]
SHEET_LIST = ['Table_1', 'Table_2', 'Table_3', 'Table_4', 'Table_5', 'Table_6']
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Tab(label="Identify objects"):
with gr.Row(equal_height=False):
input_img = gr.Image(type="filepath", label="Upload Image")
output_img = gr.Image(type="filepath", label="Output Image")
with gr.Row():
with gr.Column():
conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
with gr.Column():
iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
with gr.Row():
with gr.Column():
submit_btn = gr.Button(value="Predict")
with gr.Column():
analysis_btn = gr.Button(value="Analysis")
with gr.Row():
output_df = gr.Dataframe(label="Results")
with gr.Row():
with gr.Accordion("Gallery", open=False):
gallery = gr.Gallery(label="Detected Mark Object", columns=3)
submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
analysis_btn.click(analysis, [], [output_df])
examples = gr.Examples(
EXAMPLES,
fn=inference,
inputs=[input_img, conf_thres, iou],
outputs=[output_img, gallery],
cache_examples=False,
)
with gr.Tab(label="Detect and read table"):
with gr.Row():
with gr.Column():
                upload_pdf = gr.Image(label="PDF Preview", interactive=False)
upload_button = gr.UploadButton(label="Upload PDF file", file_types=[".pdf"])
with gr.Column():
output_img = gr.Image(label="Output Image", interactive=False)
with gr.Row():
with gr.Column():
conf_thres_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05,
label="Confidence Threshold")
with gr.Column():
iou_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
with gr.Row():
with gr.Column():
text_output = gr.Textbox(label="Table List")
with gr.Column():
file_output = gr.File()
with gr.Row():
sheet_name = gr.Dropdown(choices=SHEET_LIST, allow_custom_value=True, label="Sheet Name")
with gr.Row():
output_df = gr.Dataframe(label="Results")
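        # Re-run table analysis whenever a PDF is uploaded or a threshold changes.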
upload_button.upload(analyze_table, [upload_button, conf_thres_table, iou_table],
[upload_pdf, output_img, file_output, text_output])
conf_thres_table.change(analyze_table, [upload_button, conf_thres_table, iou_table],
[upload_pdf, output_img, file_output, text_output])
iou_table.change(analyze_table, [upload_button, conf_thres_table, iou_table],
[upload_pdf, output_img, file_output, text_output])
sheet_name.change(read_table, sheet_name, output_df)
demo.launch(debug=True)