# Spaces: Running on Zero
import os | |
os.environ["GRADIO_TEMP_DIR"] = "./tmp" | |
import sys | |
import spaces | |
import torch | |
import torchvision | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from huggingface_hub import snapshot_download | |
from visualization import visualize_bbox | |
# == download weights ==
# Fetch the DocLayout-YOLO DocStructBench repository into ./models; the
# returned location is kept in ``model_dir``.
model_dir = snapshot_download(
    'juliozhao/DocLayout-YOLO-DocStructBench',
    local_dir='./models/DocLayout-YOLO-DocStructBench',
)

# == select device ==
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Class index -> label name; handed to visualize_bbox together with the
# predicted class ids. The position in the tuple is the class index.
id_to_names = dict(enumerate((
    'title',
    'plain text',
    'abandon',
    'figure',
    'figure_caption',
    'table',
    'table_caption',
    'table_footnote',
    'isolate_formula',
    'formula_caption',
)))
def recognize_image(input_img, conf_threshold, iou_threshold):
    """Detect layout regions in an image and return the annotated result.

    Runs the module-level ``model`` on ``input_img``, applies an NMS pass
    over the predicted boxes, and hands the surviving detections to
    ``visualize_bbox``.

    Args:
        input_img: Image value received from the Gradio image component, or
            ``None`` when no image has been provided.
        conf_threshold: Confidence threshold forwarded to ``model.predict``.
        iou_threshold: IoU threshold used by ``torchvision.ops.nms``.

    Returns:
        The value returned by ``visualize_bbox`` for the kept detections, or
        ``None`` when ``input_img`` is ``None``.
    """
    # "Detect" can be clicked before any image is loaded. There is nothing to
    # run the model on, so return an empty output instead of raising.
    if input_img is None:
        return None

    det_res = model.predict(
        input_img,
        imgsz=1024,
        conf=conf_threshold,
        device=device,
    )[0]

    detections = det_res.boxes
    boxes = detections.xyxy
    classes = detections.cls
    scores = detections.conf

    # Suppress overlapping boxes using the user-selected IoU threshold.
    indices = torchvision.ops.nms(
        boxes=torch.Tensor(boxes),
        scores=torch.Tensor(scores),
        iou_threshold=iou_threshold,
    )
    boxes, scores, classes = boxes[indices], scores[indices], classes[indices]

    # Guard kept from the original: restore a leading dimension if indexing
    # collapsed the boxes to a single row.
    # NOTE(review): indexing with the NMS index tensor presumably already keeps
    # a leading dimension, which would make this branch unreachable - confirm.
    if len(boxes.shape) == 1:
        boxes = np.expand_dims(boxes, 0)
        scores = np.expand_dims(scores, 0)
        classes = np.expand_dims(classes, 0)

    vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
    return vis_result
def gradio_reset():
    """Return two Gradio updates, each setting a component's value to None.

    Wired to the "Clear" button, whose outputs are the input image and the
    output image, in that order.
    """
    cleared_input = gr.update(value=None)
    cleared_output = gr.update(value=None)
    return cleared_input, cleared_output
if __name__ == "__main__":
    # Directory containing this script; bundled resources (header, example
    # images) are resolved against it so the working directory does not matter.
    app_dir = os.path.dirname(os.path.abspath(__file__))

    # == load model ==
    from doclayout_yolo import YOLOv10

    print(f"Using device: {device}")
    # Load the checkpoint from the directory snapshot_download() reported,
    # i.e. the place the weights were actually written to.
    model = YOLOv10(
        os.path.join(model_dir, "doclayout_yolo_docstructbench_imgsz1024.pt")
    )  # load an official model

    # Explicit encoding so the HTML header is read the same way regardless of
    # the platform's default locale.
    with open(os.path.join(app_dir, "header.html"), "r", encoding="utf-8") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)

        with gr.Row():
            # Left column: input image, action buttons, thresholds, examples.
            with gr.Column():
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button(value="Clear")
                    predict = gr.Button(value="Detect", interactive=True, variant="primary")

                with gr.Row():
                    conf_threshold = gr.Slider(
                        label="Confidence Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.25,
                    )

                with gr.Row():
                    iou_threshold = gr.Slider(
                        label="NMS IOU Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        value=0.45,
                    )

                with gr.Accordion("Examples:"):
                    example_root = os.path.join(app_dir, "assets", "example")
                    # Sorted so the example gallery order is deterministic;
                    # os.listdir() returns entries in arbitrary order.
                    example_paths = [
                        os.path.join(example_root, file_name)
                        for file_name in sorted(os.listdir(example_root))
                        if file_name.endswith("jpg")
                    ]
                    gr.Examples(
                        examples=example_paths,
                        inputs=[input_img],
                    )

            # Right column: the annotated detection result.
            with gr.Column():
                gr.Button(value="Predict Result:", interactive=False)
                output_img = gr.Image(label=" ", interactive=False)

        # "Clear" resets both images; "Detect" runs layout detection.
        clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
        predict.click(
            recognize_image,
            inputs=[input_img, conf_threshold, iou_threshold],
            outputs=[output_img],
        )

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)