File size: 4,467 Bytes
dd08fd0
696caca
dd08fd0
696caca
dd08fd0
696caca
dd08fd0
696caca
 
dd08fd0
696caca
 
 
 
 
dd08fd0
696caca
 
 
 
 
 
 
 
dd08fd0
696caca
 
 
 
 
 
 
 
 
 
 
 
 
dd08fd0
 
 
 
696caca
dd08fd0
 
696caca
dd08fd0
 
696caca
dd08fd0
 
 
 
696caca
dd08fd0
 
 
 
696caca
dd08fd0
 
696caca
dd08fd0
 
 
696caca
dd08fd0
 
 
 
696caca
 
dd08fd0
 
696caca
dd08fd0
 
 
 
 
 
 
 
 
 
 
696caca
dd08fd0
 
 
 
 
 
 
 
696caca
dd08fd0
696caca
 
 
dd08fd0
 
 
696caca
dd08fd0
696caca
dd08fd0
 
696caca
 
dd08fd0
 
 
696caca
dd08fd0
696caca
dd08fd0
696caca
dd08fd0
696caca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import base64
import html
import logging
import os
import tempfile

import fitz
import gradio as gr
import spaces
import torch
from PIL import Image, ImageEnhance
from transformers import AutoModel, AutoTokenizer

# Hugging Face checkpoint for GOT-OCR 2.0. Its model code is loaded from the
# checkpoint repository (``trust_remote_code=True``); it provides the
# ``chat`` / ``chat_crop`` methods used further down in this file.
model_name = "ucaslcl/GOT-OCR2_0"

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# ``eval()`` and ``to()`` both return the module itself, so they are chained
# directly onto the freshly loaded model (inference mode, on ``device``).
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
).eval().to(device)


def pdf_to_images(pdf_path, zoom=10, contrast=1.5):
    """Render every page of a PDF to an RGB PIL image with boosted contrast.

    Args:
        pdf_path: Path of the PDF file, opened with PyMuPDF (``fitz``).
        zoom: Scale factor applied to both axes when rasterising a page.
            PyMuPDF's base resolution is 72 dpi, so the default of 10 yields
            720 dpi; e.g. a US-Letter page becomes about 6120x7920 pixels.
        contrast: Factor passed to ``ImageEnhance.Contrast.enhance``; the
            default 1.5 raises contrast by 50%.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per page, in page order.
        NOTE: all pages are held in memory at once, which is large at the
        default zoom.
    """
    # The scaling matrix does not depend on the page, so build it once.
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            images.append(ImageEnhance.Contrast(img).enhance(contrast))
    return images


@spaces.GPU()
def ocr_process(file, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run GOT-OCR on an uploaded image or PDF and return HTML for display.

    PDFs are rasterised page by page into a temporary directory and each page
    is recognised separately; any other file is passed straight through.

    Args:
        file: Value of the ``gr.File`` input. Either a filepath (``str`` /
            path-like, as Gradio 4 passes by default) or a tempfile-like
            object whose ``.name`` is the path (older Gradio versions).
        got_mode: OCR mode label chosen in the UI dropdown.
        ocr_color: Colour hint forwarded to ``process_single_image``.
        ocr_box: Bounding-box hint forwarded to ``process_single_image``.
        progress: Progress tracker injected by Gradio.

    Returns:
        The recognition result as a string destined for a ``gr.HTML``
        component; for PDFs, one "第 N 页结果" section per page. Failures are
        not raised: they are logged and returned as a string starting with
        "错误".

    NOTE(review): ``spaces.GPU`` limits how long one call may hold the GPU;
    long PDFs may need an explicit ``duration=`` — confirm on the Space.
    """
    if file is None:
        return "错误:未提供文件"

    try:
        progress(0, desc="开始处理...")
        # Accept both a plain path and a file wrapper exposing ``.name``.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name

        with tempfile.TemporaryDirectory() as temp_dir:
            if file_path.lower().endswith(".pdf"):
                images = pdf_to_images(file_path)
                num_pages = len(images)
                results = []

                for page_no, image in enumerate(images, start=1):
                    # Report the fraction already completed, not the page being started.
                    progress((page_no - 1) / num_pages, desc=f"处理第 {page_no}/{num_pages} 页...")
                    img_path = os.path.join(temp_dir, f"page_{page_no}.png")
                    image.save(img_path, "PNG")

                    result = process_single_image(img_path, got_mode, ocr_color, ocr_box)
                    results.append(f"第 {page_no} 页结果:\n{result}")

                final_result = "\n\n".join(results)
            else:
                final_result = process_single_image(file_path, got_mode, ocr_color, ocr_box)

        progress(1, desc="处理完成")
        return final_result
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the server log, and
        # escape the message because the output component renders HTML.
        logging.getLogger(__name__).exception("OCR processing failed")
        return f"错误: {html.escape(str(e))}"


def process_single_image(image_path, got_mode, ocr_color, ocr_box):
    """Run GOT-OCR on one image file and return HTML for the output panel.

    Args:
        image_path: Path of the image to recognise.
        got_mode: Mode label from the UI, e.g. "plain texts OCR" or
            "format multi-crop OCR". "plain" selects ``ocr_type="ocr"`` and
            "format" selects ``ocr_type="format"`` (with HTML rendering).
            "multi-crop" routes to ``model.chat_crop``, and "fine-grained"
            enables the ``ocr_box`` / ``ocr_color`` hints.
        ocr_color: Colour hint; only used in fine-grained modes.
        ocr_box: Bounding-box hint; only used in fine-grained modes.

    Returns:
        For "plain" modes, the recognised text HTML-escaped inside ``<pre>``.
        For "format" modes, a download link plus an ``<iframe>`` preview of
        the rendered HTML file written next to ``image_path``. If that file
        was not produced, the model's result is shown as escaped text instead.
        An empty or unrecognised mode yields "错误: 未知的OCR模式".
    """
    # An empty/cleared dropdown would otherwise raise TypeError on ``in``.
    if not got_mode:
        return "错误: 未知的OCR模式"

    if "plain" in got_mode:
        ocr_type = "ocr"
    elif "format" in got_mode:
        ocr_type = "format"
    else:
        return "错误: 未知的OCR模式"

    # The UI labels the colour/box fields as fine-grained only, so ignore any
    # values left in them while another mode is selected.
    if "fine-grained" not in got_mode:
        ocr_color = ""
        ocr_box = ""

    render = ocr_type == "format"
    result_path = f"{os.path.splitext(image_path)[0]}_result.html"
    # Only "format" mode asks the model to write a rendered HTML file.
    render_kwargs = {"render": True, "save_render_file": result_path} if render else {}

    if "multi-crop" in got_mode:
        res = model.chat_crop(tokenizer, image_path, ocr_type=ocr_type, **render_kwargs)
    else:
        res = model.chat(
            tokenizer,
            image_path,
            ocr_type=ocr_type,
            ocr_box=ocr_box,
            ocr_color=ocr_color,
            **render_kwargs,
        )

    if render and os.path.exists(result_path):
        return _rendered_file_to_html(result_path)
    # Plain mode, or format mode whose render file was not written: show the
    # recognised text rather than discarding it behind a wrong error message.
    return _text_to_html(res)


def _text_to_html(text):
    """Return *text* HTML-escaped inside a line-wrapping ``<pre>`` block."""
    # The output component is gr.HTML: raw text would be parsed as markup and
    # its newlines collapsed.
    return f'<pre style="white-space: pre-wrap;">{html.escape(str(text))}</pre>'


def _rendered_file_to_html(result_path):
    """Embed a rendered HTML file as a base64 data URI: download link + iframe preview."""
    with open(result_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
    return f"{download_link}\n\n{preview}"


# Gradio UI. Inside a ``gr.Blocks`` context, components are laid out in the
# order they are created, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    # Page title ("OCR image recognition").
    gr.Markdown("# OCR 图像识别")

    # Upload control ("Upload a PDF or image file"); ocr_process branches on
    # a ".pdf" extension and treats anything else as a single image.
    file_input = gr.File(label="上传PDF或图片文件")

    # OCR mode selector ("OCR mode"). The OCR functions pick their behaviour
    # via substring checks on these labels (e.g. "plain", "format",
    # "multi-crop"), so renaming a choice changes what it does.
    got_mode = gr.Dropdown(
        choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
        label="OCR模式",
        value="plain texts OCR",
    )

    # Optional hints, labelled "OCR colour" / "OCR bounding box" and marked
    # "only for fine-grained mode"; passed through to the model as strings.
    with gr.Row():
        ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
        ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")

    # "Start OCR recognition" button.
    submit_button = gr.Button("开始OCR识别")

    # Result panel ("Recognition result"). gr.HTML renders the returned string
    # as markup, which is how the format-mode iframe/download link is shown.
    output = gr.HTML(label="识别结果")

    # The inputs are passed positionally to ocr_process as
    # (file, got_mode, ocr_color, ocr_box). Its ``progress`` parameter is not
    # listed here because it is filled by Gradio's gr.Progress mechanism.
    submit_button.click(ocr_process, inputs=[file_input, got_mode, ocr_color, ocr_box], outputs=output)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()