from ultralytics import YOLO
import supervision as sv
import cv2
import gradio as gr
import os
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import glob
import pandas as pd
import time
from pdf2image import convert_from_path
import pymupdf
import camelot
import subprocess

# Install flash-attn at startup, skipping the CUDA build step.
# os.environ is merged in so pip still sees PATH and the other environment variables.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
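# Florence-2 (base, fine-tuned) performs OCR on the cropped detections.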
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
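# YOLO ONNX weights: one detector for "mark" objects, one for tables.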
onnx_model = YOLO("models/best.onnx", task='detect')
onnx_model_table = YOLO("models/tables/best.onnx", task='detect')
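# Keep only detections whose class name matches target_class_name, dropping all others in place.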
def filter_detections(detections, target_class_name="mark"):
    indices_to_keep = [i for i, class_name in enumerate(detections.data['class_name'])
                       if class_name == target_class_name]
    detections.xyxy = detections.xyxy[indices_to_keep]
    detections.confidence = detections.confidence[indices_to_keep]
    detections.class_id = detections.class_id[indices_to_keep]
    detections.data['class_name'] = detections.data['class_name'][indices_to_keep]
    return detections
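# Number each detection ("table 1", "table 2", ...) and give it a unique class_id
# so the annotator renders each one with its own label.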
def add_label_detection(detections):
    updated_class = [f"{class_name} {i + 1}" for i, class_name in enumerate(detections.data['class_name'])]
    updated_id = [class_id + i for i, class_id in enumerate(detections.class_id)]
    detections.data['class_name'] = np.array(updated_class)
    detections.class_id = np.array(updated_id)
    return detections
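# Mark labels are expected to end with a digit; analysis() appends "1" when they don't.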
def ends_with_number(s):
    return bool(s) and s[-1].isdigit()
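# Run Florence-2 on a single OpenCV image and return the parsed task output.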
def ocr(image, prompt="<OCR>"):
    original_height, original_width = image.shape[:2]
    # cv2 loads images as BGR; the processor expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(original_width, original_height),
    )
    return parsed_answer
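# Flatten a supervision Detections object into a list of per-box dicts
# (top/left/width/height plus class, confidence, and any extra metadata).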
def parse_detection(detections):
    parsed_rows = []
    for i in range(len(detections.xyxy)):
        x_min, y_min, x_max, y_max = (float(v) for v in detections.xyxy[i])
        row = {
            "top": int(y_min),
            "left": int(x_min),
            "width": int(x_max - x_min),
            "height": int(y_max - y_min),
            "class_id": "" if detections.class_id is None else int(detections.class_id[i]),
            "confidence": "" if detections.confidence is None else float(detections.confidence[i]),
            "tracker_id": "" if detections.tracker_id is None else int(detections.tracker_id[i]),
        }
        if hasattr(detections, "data"):
            for key, value in detections.data.items():
                row[key] = (
                    str(value[i])
                    if hasattr(value, "__getitem__") and value.ndim != 0
                    else str(value)
                )
        parsed_rows.append(row)
    return parsed_rows
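# Crop every 'mark' detection from the page image, upscale it for OCR, and save it as a PNG.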
def cut_and_save_image(image, parsed_detections, output_dir):
    output_path_list = []
    for i, det in enumerate(parsed_detections):
        # Check if the class is 'mark'
        if det['class_name'] == 'mark':
            top, left = det['top'], det['left']
            width, height = det['width'], det['height']
            # Cut the image
            cut_image = image[top:top + height, left:left + width]
            # Upscale the crop so small text survives OCR.
            scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
            # Save the image; it is a PNG, so no JPEG quality flag is needed.
            output_path = f"{output_dir}/cut_image_{i}.png"
            cv2.imwrite(output_path, scaled_image)
            output_path_list.append(output_path)
    return output_path_list
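# OCR every cropped mark in output/ and tally how often each label appears.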
def analysis(progress=gr.Progress()):
    progress(0, desc="Analyzing...")
    list_files = glob.glob("output/*.png")
    prompt = "<OCR>"
    results = {}
    for filepath in progress.tqdm(list_files):
        basename = os.path.basename(filepath)
        image = cv2.imread(filepath)
        start_time = time.time()
        parsed_answer = ocr(image, prompt)
        if not ends_with_number(parsed_answer[prompt]):
            parsed_answer[prompt] += "1"
        results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
        print(basename, parsed_answer[prompt])
        print("Time taken:", time.time() - start_time)
    return pd.DataFrame(results.items(), columns=['Mark', 'Total']).reset_index(drop=False).rename(columns={'index': 'No.'})
def inference(
    image_path,
    conf_threshold,
    iou_threshold,
):
    """
    YOLOv8 inference function
    Args:
        image_path: Path to the image
        conf_threshold: Confidence threshold
        iou_threshold: IoU threshold
    Returns:
        Annotated image and the list of cropped 'mark' image paths
    """
    image = cv2.imread(image_path)
    results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
    detections = sv.Detections.from_ultralytics(results)
    detections = filter_detections(detections)
    parsed_detections = parse_detection(detections)
    output_dir = "output"
    # Create the output directory, or clear any files left from a previous run.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = box_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    # Convert BGR back to RGB so Gradio displays the colors correctly.
    return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), output_path_list
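# Load one sheet of the exported workbook back into a DataFrame for display.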
def read_table(sheet):
    excel_path = "output_tables.xlsx"
    if os.path.exists(excel_path):
        sheetnames = pd.ExcelFile(excel_path).sheet_names
        df = pd.read_excel(excel_path, sheet_name=sheet) if sheet in sheetnames else pd.DataFrame()
    else:
        df = pd.DataFrame()
    return df
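# Tables parsed without headers come back with integer column names;
# rename them "Col 1", "Col 2", ... so the Excel and Gradio headers are readable.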
def validate_df(df):
    columns = []
    count = 1
    for col in df.columns:
        if isinstance(col, int):
            columns.append(f"Col {count}")
            count += 1
        else:
            columns.append(col)
    df.columns = columns
    return df
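# Detect tables on the first page of the uploaded PDF, extract each one
# (PyMuPDF first, Camelot as a fallback), and export them all to an Excel workbook.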
def analyze_table(file, conf_threshold, iou_threshold, progress=gr.Progress()):
    progress(0, desc="Parsing table...")
    img = convert_from_path(file)[0]
    doc = pymupdf.open(file)
    zoom_x = 1.0  # horizontal zoom
    zoom_y = 1.0  # vertical zoom
    mat = pymupdf.Matrix(zoom_x, zoom_y)
    # Render only the first page: the table extraction below also reads page 0.
    pix = doc[0].get_pixmap(matrix=mat)
    pix.save("temp.png")
    image = cv2.imread("temp.png")
    file_height, file_width, _ = image.shape
    results = onnx_model_table(image, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
    detections = sv.Detections.from_ultralytics(results)
    detections = add_label_detection(detections)
    parsed_detections = parse_detection(detections)
    output_dir = "output_table"
    # Create the output directory, or clear any files left from a previous run.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = box_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    # At zoom 1.0 the pixmap pixels line up with PDF points, so detection boxes map directly to PDF rects.
    pdf_page = doc[0]
    table_area = [(ind, pymupdf.Rect(det['left'], det['top'],
                                     det['left'] + det['width'], det['top'] + det['height']))
                  for ind, det in enumerate(parsed_detections)]
    table_list = []
    for ind, area in progress.tqdm(table_area):
        # Try PyMuPDF's built-in table finder first.
        pdf_tabs = pdf_page.find_tables(clip=area)
        if len(pdf_tabs.tables) > 0:
            pdf_df = pdf_tabs[0].to_pandas()
            print("PyMuPDF table found!")
        else:
            # Fall back to Camelot; its coordinates are measured from the bottom-left corner of the page.
            cur = parsed_detections[ind]
            table_areas = [f"{cur['left']},{file_height - cur['top']},"
                           f"{cur['left'] + cur['width']},{file_height - (cur['top'] + cur['height'])}"]
            # Camelot page numbers start at 1.
            tables = camelot.read_pdf(file, pages='1', flavor='stream', row_tol=10, table_areas=table_areas)
            pdf_df = tables[0].df if len(tables) > 0 else pd.DataFrame()
            print("Camelot table found!")
        pdf_df = validate_df(pdf_df)
        table_list.append(pdf_df)
    excel_path = "output_tables.xlsx"
    sheet_list = []
    # Write each extracted table to its own sheet.
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
        for i, table in enumerate(table_list):
            sheet_name = f"Table_{i + 1}"
            table.to_excel(writer, sheet_name=sheet_name, index=False)
            sheet_list.append(sheet_name)
    # Convert BGR back to RGB so Gradio displays the colors correctly.
    return img, cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), excel_path, ", ".join(sheet_list)
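# --- Gradio UI ---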
TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
identification application. This tool allows you to upload an image, and it will identify and annotate objects within
the image. Additionally, you can perform OCR analysis on the detected objects.</p>
"""
CSS = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
h1 {
    text-align: center;
}
"""
EXAMPLES = [
    ['examples/train1.png', 0.6, 0.25],
    ['examples/train2.png', 0.9, 0.25],
    ['examples/train3.png', 0.6, 0.25],
]
SHEET_LIST = ['Table_1', 'Table_2', 'Table_3', 'Table_4', 'Table_5', 'Table_6']
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    with gr.Tab(label="Identify objects"):
        with gr.Row(equal_height=False):
            input_img = gr.Image(type="filepath", label="Upload Image")
            output_img = gr.Image(type="filepath", label="Output Image")
        with gr.Row():
            with gr.Column():
                conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
            with gr.Column():
                iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
        with gr.Row():
            with gr.Column():
                submit_btn = gr.Button(value="Predict")
            with gr.Column():
                analysis_btn = gr.Button(value="Analysis")
        with gr.Row():
            output_df = gr.Dataframe(label="Results")
        with gr.Row():
            with gr.Accordion("Gallery", open=False):
                gallery = gr.Gallery(label="Detected Mark Object", columns=3)
        submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
        analysis_btn.click(analysis, [], [output_df])
        examples = gr.Examples(
            EXAMPLES,
            fn=inference,
            inputs=[input_img, conf_thres, iou],
            outputs=[output_img, gallery],
            cache_examples=False,
        )
with gr.Tab(label="Detect and read table"): | |
with gr.Row(): | |
with gr.Column(): | |
upload_pdf = gr.Image(label="Upload PDF file") | |
upload_button = gr.UploadButton(label="Upload PDF file", file_types=[".pdf"]) | |
with gr.Column(): | |
output_img = gr.Image(label="Output Image", interactive=False) | |
with gr.Row(): | |
with gr.Column(): | |
conf_thres_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.05, | |
label="Confidence Threshold") | |
with gr.Column(): | |
iou_table = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold") | |
with gr.Row(): | |
with gr.Column(): | |
text_output = gr.Textbox(label="Table List") | |
with gr.Column(): | |
file_output = gr.File() | |
with gr.Row(): | |
sheet_name = gr.Dropdown(choices=SHEET_LIST, allow_custom_value=True, label="Sheet Name") | |
with gr.Row(): | |
output_df = gr.Dataframe(label="Results") | |
upload_button.upload(analyze_table, [upload_button, conf_thres_table, iou_table], | |
[upload_pdf, output_img, file_output, text_output]) | |
conf_thres_table.change(analyze_table, [upload_button, conf_thres_table, iou_table], | |
[upload_pdf, output_img, file_output, text_output]) | |
iou_table.change(analyze_table, [upload_button, conf_thres_table, iou_table], | |
[upload_pdf, output_img, file_output, text_output]) | |
sheet_name.change(read_table, sheet_name, output_df) | |
demo.launch(debug=True) |