File size: 4,956 Bytes
b764ffe
 
0aa924d
be425b2
8242b6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec23149
0aa924d
8242b6e
 
 
 
 
 
 
 
 
 
 
 
e134b51
8242b6e
 
 
 
 
 
 
6cd21dc
c4dd123
8242b6e
0aa924d
8242b6e
 
 
 
 
 
 
 
 
 
 
0aa924d
8242b6e
 
 
 
 
 
 
 
 
0aa924d
8242b6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa924d
 
c4dd123
0aa924d
 
 
c4dd123
8242b6e
f7a222c
8242b6e
f7a222c
8242b6e
 
6cd21dc
8242b6e
0aa924d
8242b6e
b764ffe
0aa924d
c4dd123
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from ultralytics import YOLO
import spaces
import torch
import cv2
import numpy as np
import os
import requests

# Drawing colour per DLA layout class, as (class name, colour) pairs.
# The ORDER of this table is significant: its keys become ``class_names_2``,
# which is indexed by the DLA model's predicted class id, so entry i must be
# the model's class i. The colour tuples are passed straight to OpenCV drawing
# calls on a BGR image, so OpenCV reads each one as (B, G, R).
_ENTITY_COLOR_TABLE = (
    ("Caption", (191, 100, 21)),
    ("Footnote", (2, 62, 115)),
    ("Formula", (140, 80, 58)),
    ("List-item", (168, 181, 69)),
    ("Page-footer", (2, 69, 84)),
    ("Page-header", (83, 115, 106)),
    ("Picture", (255, 72, 88)),
    ("Section-header", (0, 204, 192)),
    ("Table", (116, 127, 127)),
    ("Text", (0, 153, 221)),
    ("Title", (196, 51, 2)),
)
ENTITIES_COLORS = dict(_ENTITY_COLOR_TABLE)

# Gap in pixels between a label's text and the edges of its filled background tag.
BOX_PADDING = 2

# Load pre-trained YOLOv8 models.
model_path_1 = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
model_path_2 = "models/dla-model.pt"
model_url_1 = "https://huggingface.co/DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet/resolve/main/yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"


def _download_file(url, dest, timeout=(10, 60), chunk_size=1 << 20):
    """Stream ``url`` into the file ``dest``.

    The response body is written to ``dest + ".part"`` in chunks and only
    renamed to ``dest`` once the whole body has been received. An interrupted
    download, or an HTTP error response, therefore never leaves a truncated or
    bogus file at ``dest``. The ``os.path.exists`` check used at startup would
    otherwise accept such a file on every later start.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final path of the downloaded file.
        timeout: ``requests`` (connect, read) timeout in seconds.
        chunk_size: Bytes per chunk read from the response stream.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # Only left behind if something above failed; never keep partial data.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


if not os.path.exists(model_path_1):
    # Download the first model file if it doesn't exist locally.
    _download_file(model_url_1, model_path_1)

if not os.path.exists(model_path_2):
    # The second model file is not downloadable from here; it must be uploaded
    # manually to this path. Fail early with a clear message instead of letting
    # YOLO() fail later on a missing checkpoint.
    raise FileNotFoundError(
        f"DLA model weights not found at '{model_path_2}'. "
        "Upload the checkpoint to this path before starting the app."
    )

# Load models
model_1 = YOLO(model_path_1)
model_2 = YOLO(model_path_2)

# Class-id -> name mapping for model 1 comes from the checkpoint itself.
class_names_1 = model_1.names
# Model 2 uses the ENTITIES_COLORS key order as its class-id -> name mapping;
# NOTE(review): assumes that order matches the DLA checkpoint's class ids — confirm.
class_names_2 = list(ENTITIES_COLORS.keys())

# OpenCV font used both to MEASURE and to DRAW DLA label text, so the filled
# background tag is sized for the glyphs that are actually rendered.
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX


def _format_detections(names, boxes):
    """Return one ``"NAME: 0.87"`` line per detected box.

    Args:
        names: Class-id -> class-name lookup (dict or list indexed by int id).
        boxes: Ultralytics ``Boxes`` for one image, iterated box by box.

    Returns:
        Newline-joined lines with the upper-cased class name and the confidence
        to two decimals, or ``"No regions detected."`` when there are no boxes.
    """
    lines = [
        f"{names[int(box.cls.item())].upper()}: {float(box.conf):.2f}" for box in boxes
    ]
    return "\n".join(lines) if lines else "No regions detected."


def _draw_dla_boxes(image_bgr, boxes):
    """Draw DLA detections (outline plus filled "class conf" tag) on an image.

    Args:
        image_bgr: Image array in OpenCV BGR channel order; drawn on by OpenCV.
        boxes: Ultralytics ``Boxes`` predicted by ``model_2`` for this image.

    Returns:
        The annotated image array (still BGR).
    """
    # Stroke width scales with the mean image side; it is the same for every box.
    line_thickness = round(0.002 * (image_bgr.shape[0] + image_bgr.shape[1]) / 2) + 1
    font_thickness = max(line_thickness - 1, 1)
    font_scale = line_thickness / 3

    for box in boxes:
        cls = class_names_2[int(box.cls)]
        conf = round(box.conf.item(), 2)
        color = ENTITIES_COLORS[cls]
        x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])

        image_bgr = cv2.rectangle(image_bgr, (x1, y1), (x2, y2), color, line_thickness)

        text = f"{cls} {conf}"
        (text_w, text_h), _ = cv2.getTextSize(text, _LABEL_FONT, font_scale, font_thickness)
        tag_h = text_h + 2 * BOX_PADDING
        # Place the tag above the box; if the box touches the top edge the tag
        # would be cut off, so drop it just inside the box instead.
        tag_bottom = y1 if y1 >= tag_h else y1 + tag_h
        image_bgr = cv2.rectangle(
            image_bgr,
            (x1, tag_bottom - tag_h),
            (x1 + text_w + 2 * BOX_PADDING, tag_bottom),
            color,
            -1,  # negative thickness -> filled rectangle
        )
        image_bgr = cv2.putText(
            image_bgr,
            text,
            (x1 + BOX_PADDING, tag_bottom - BOX_PADDING),
            _LABEL_FONT,
            font_scale,
            (255, 255, 255),
            font_thickness,
        )
    return image_bgr


@spaces.GPU(duration=60)
def process_image(image, model_choice):
    """Run the selected layout model on ``image`` and annotate the detections.

    Args:
        image: Uploaded page as a PIL image (the input is ``gr.Image(type="pil")``),
            or ``None`` when the button is pressed without an upload.
        model_choice: ``"YOLOv8 Model"`` or ``"DLA Model"``.

    Returns:
        Always a 2-tuple ``(annotated_image, text)`` to match the two Gradio
        outputs. ``annotated_image`` is an RGB array, or ``None`` on failure.
        ``text`` lists one ``NAME: conf`` line per detection, or carries an
        error/help message.
    """
    # Ultralytics substitutes its bundled sample images when source is None,
    # so reject a missing upload explicitly.
    if image is None:
        return None, "Please upload an image."

    try:
        if model_choice == "YOLOv8 Model":
            result = model_1(source=image, save=False, show_labels=True, show_conf=True, show_boxes=True)[0]
            # Results.plot() returns a BGR array; convert so Gradio shows true colours.
            annotated_rgb = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
            return annotated_rgb, _format_detections(class_names_1, result.boxes)

        if model_choice == "DLA Model":
            # Convert in memory rather than round-tripping through a shared
            # "input_image.jpg" file: no lossy JPEG re-encode, and concurrent
            # requests cannot overwrite each other's input. Ultralytics expects
            # numpy input in BGR order.
            image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            boxes = model_2.predict(source=image_bgr, conf=0.2, iou=0.8)[0].boxes
            annotated_bgr = _draw_dla_boxes(image_bgr, boxes)
            return (
                cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB),
                _format_detections(class_names_2, boxes),
            )

        return None, "Invalid model choice"

    except Exception as e:
        # UI boundary: report the failure in the textbox instead of crashing the app.
        return None, f"Error processing image: {e}"

# Build the Gradio UI. The top row has the model picker and the image upload
# side by side. Below it are the annotated image, the detection text box, and
# a button that sends (image, model) to process_image.
with gr.Blocks() as demo:
    gr.Markdown("# Document Segmentation Demo (ZeroGPU)")

    with gr.Row():
        model_selector = gr.Dropdown(
            ["YOLOv8 Model", "DLA Model"],
            label="Select Model",
            value="YOLOv8 Model",
        )
        image_input = gr.Image(type="pil", label="Upload Image")

    annotated_output = gr.Image(type="pil", label="Annotated Image")
    labels_output = gr.Textbox(label="Detected Areas and Labels")

    run_button = gr.Button("Run Document Segmentation")
    run_button.click(
        fn=process_image,
        inputs=[image_input, model_selector],
        outputs=[annotated_output, labels_output],
    )

# Enable request queuing with room for at most one pending event, then serve.
demo.queue(max_size=1)
demo.launch()