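"""Gradio demo for document layout segmentation on Hugging Face Spaces (ZeroGPU).

Two detectors are offered: a YOLOv8 model fine-tuned on DocLayNet (downloaded from
the Hugging Face Hub if not present) and a locally provided DLA model whose
detections are drawn manually with OpenCV.
"""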
import gradio as gr
from ultralytics import YOLO
import spaces
import torch
import cv2
import numpy as np
import os
import requests
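# Dependencies (from the imports above): gradio, ultralytics, spaces (Hugging Face
# ZeroGPU helper), torch, opencv-python, numpy, requests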
# Constants used when drawing detections from the DLA model
ENTITIES_COLORS = {
    "Caption": (191, 100, 21),
    "Footnote": (2, 62, 115),
    "Formula": (140, 80, 58),
    "List-item": (168, 181, 69),
    "Page-footer": (2, 69, 84),
    "Page-header": (83, 115, 106),
    "Picture": (255, 72, 88),
    "Section-header": (0, 204, 192),
    "Table": (116, 127, 127),
    "Text": (0, 153, 221),
    "Title": (196, 51, 2)
}
BOX_PADDING = 2
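# Note: the color tuples above are passed directly to OpenCV drawing calls, which
# interpret them as BGR. BOX_PADDING is the pixel margin around the label text.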
# Paths to the pre-trained model weights
model_path_1 = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
model_path_2 = "models/dla-model.pt"
if not os.path.exists(model_path_1):
    # Download the model file if it doesn't exist
    model_url_1 = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
    response = requests.get(model_url_1)
    response.raise_for_status()
    with open(model_path_1, "wb") as f:
        f.write(response.content)
if not os.path.exists(model_path_2):
    # The second model file is assumed to be uploaded manually to this path;
    # fail early with a clear message if it is missing.
    raise FileNotFoundError(f"DLA model weights not found at {model_path_2}")

# Load models
model_1 = YOLO(model_path_1)
model_2 = YOLO(model_path_2)
# Class names: model 1 exposes its own names; model 2 uses the ENTITIES_COLORS keys
class_names_1 = model_1.names
class_names_2 = list(ENTITIES_COLORS.keys())
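# The @spaces.GPU decorator below requests a ZeroGPU allocation for up to `duration`
# seconds per call; per the `spaces` package docs it is a no-op outside a ZeroGPU Space.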
@spaces.GPU(duration=60)
def process_image(image, model_choice):
    try:
        if model_choice == "YOLOv8 Model":
            # Use the first model
            results = model_1(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)
            result = results[0]
            # Extract annotated image and labels with class names
            annotated_image = result.plot()
            detected_areas_labels = "\n".join([
                f"{class_names_1[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in result.boxes
            ])
            return annotated_image, detected_areas_labels
        elif model_choice == "DLA Model":
            # Use the second model: run detection, then draw boxes and labels with OpenCV
            image_path = "input_image.jpg"  # Temporarily save the uploaded image
            cv2.imwrite(image_path, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
            image = cv2.imread(image_path)
            results = model_2.predict(source=image, conf=0.2, iou=0.8)
            boxes = results[0].boxes
            if len(boxes) == 0:
                return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "No regions detected"
            for box in boxes:
                detection_class_conf = round(box.conf.item(), 2)
                cls = class_names_2[int(box.cls)]
                start_box = (int(box.xyxy[0][0]), int(box.xyxy[0][1]))
                end_box = (int(box.xyxy[0][2]), int(box.xyxy[0][3]))
                # Draw the bounding box
                line_thickness = round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
                image = cv2.rectangle(img=image,
                                      pt1=start_box,
                                      pt2=end_box,
                                      color=ENTITIES_COLORS[cls],
                                      thickness=line_thickness)
                # Draw a filled label background above the box, then the label text
                text = cls + " " + str(detection_class_conf)
                font_thickness = max(line_thickness - 1, 1)
                (text_w, text_h), _ = cv2.getTextSize(text=text, fontFace=2, fontScale=line_thickness/3, thickness=font_thickness)
                image = cv2.rectangle(img=image,
                                      pt1=(start_box[0], start_box[1] - text_h - BOX_PADDING*2),
                                      pt2=(start_box[0] + text_w + BOX_PADDING * 2, start_box[1]),
                                      color=ENTITIES_COLORS[cls],
                                      thickness=-1)
                start_text = (start_box[0] + BOX_PADDING, start_box[1] - BOX_PADDING)
                image = cv2.putText(img=image, text=text, org=start_text, fontFace=0, color=(255, 255, 255), fontScale=line_thickness/3, thickness=font_thickness)
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "Labels: " + ", ".join(class_names_2)
        else:
            return None, "Invalid model choice"
    except Exception as e:
        return None, f"Error processing image: {e}"
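# Minimal local usage sketch (illustrative; "sample_page.png" is a placeholder filename):
#   from PIL import Image
#   annotated, labels = process_image(Image.open("sample_page.png"), "YOLOv8 Model")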
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Segmentation Demo (ZeroGPU)")
    with gr.Row():
        model_choice = gr.Dropdown(["YOLOv8 Model", "DLA Model"], label="Select Model", value="YOLOv8 Model")
        input_image = gr.Image(type="pil", label="Upload Image")
    output_image = gr.Image(type="pil", label="Annotated Image")
    output_text = gr.Textbox(label="Detected Areas and Labels")
    btn = gr.Button("Run Document Segmentation")
    btn.click(fn=process_image, inputs=[input_image, model_choice], outputs=[output_image, output_text])

# Launch the demo with queuing (at most one pending request in the queue)
demo.queue(max_size=1).launch()