# Spaces: Running
import atexit
import base64
import html
import os
import shutil
import tempfile
import time

import fitz
import gradio as gr
import spaces
import torch
from PIL import Image, ImageEnhance
from transformers import AutoModel, AutoTokenizer
# GOT-OCR 2.0 checkpoint on the Hugging Face Hub.
model_name = "ucaslcl/GOT-OCR2_0"

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# trust_remote_code lets transformers execute the model/tokenizer code that
# ships inside the checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
model.eval()  # inference only
model = model.to(device)

# One process-wide temporary directory for rendered page images and the
# OCR render files written next to them.
TEMP_DIR = tempfile.mkdtemp()
@atexit.register
def cleanup():
    """Delete TEMP_DIR and everything under it, ignoring any errors.

    Registered with ``atexit`` so the directory is removed when the
    interpreter shuts down.
    """
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
def pdf_to_images(pdf_path, zoom=10, contrast=1.5):
    """Render every page of a PDF to a contrast-enhanced RGB PIL image.

    Args:
        pdf_path: Path to the PDF file.
        zoom: Scale factor applied to both axes when rasterising. PyMuPDF's
            identity matrix renders at 72 DPI, so the default of 10 gives
            roughly 720 DPI. Large values produce very large images, and all
            pages are kept in memory at once.
        contrast: Factor passed to ``ImageEnhance.Contrast``; the default of
            1.5 raises contrast by 50%.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per page, in page order.
    """
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            images.append(ImageEnhance.Contrast(img).enhance(contrast))
    return images
def convert_pdf_to_images(file):
    """Gradio handler: rasterise an uploaded PDF and save its pages as PNGs.

    Args:
        file: Value of the ``gr.File`` input. This is either a path string or
            an object whose ``.name`` attribute holds the path; it is None
            when nothing was uploaded.

    Returns:
        ``(status message, list of PNG paths)`` on success, or
        ``(error message, None)`` on failure.
    """
    if file is None:
        return "错误:未提供文件", None
    try:
        # Accept both a plain path string and a file-like object with .name.
        pdf_path = getattr(file, "name", file)
        if not pdf_path.lower().endswith(".pdf"):
            return "错误:请上传PDF文件", None
        images = pdf_to_images(pdf_path)
        # A fresh sub-directory per conversion stops concurrent sessions (and
        # successive uploads) from overwriting each other's page_N.png files.
        out_dir = tempfile.mkdtemp(dir=TEMP_DIR)
        image_paths = []
        for page_no, image in enumerate(images, start=1):
            img_path = os.path.join(out_dir, f"page_{page_no}.png")
            image.save(img_path, "PNG")
            image_paths.append(img_path)
        return "PDF转换为图片成功", image_paths
    except Exception as e:
        return f"错误: {str(e)}", None
def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Gradio handler: OCR the selected image and return HTML for the result panel.

    Args:
        image: Path of the image chosen in the gallery, or None if none is selected.
        got_mode: OCR mode label from the dropdown.
        ocr_color: Colour value from the UI (intended for fine-grained modes),
            forwarded unchanged.
        ocr_box: Box value from the UI (intended for fine-grained modes),
            forwarded unchanged.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        The HTML produced by ``process_single_image``, or an error message.
    """
    if image is None:
        return "错误:未选择图片"
    try:
        # Progress is reported only around the real OCR call; there are no
        # simulated stages or artificial delays.
        progress(0, desc="开始处理...")
        progress(0.2, desc="文字识别中...")
        result = process_single_image(image, got_mode, ocr_color, ocr_box)
        progress(1, desc="处理完成")
        return result
    except Exception as e:
        # The output component is gr.HTML, so escape the exception text.
        return f"错误: {html.escape(str(e))}"
def _as_preformatted_html(text):
    """Return *text* HTML-escaped inside a line-wrapping <pre> element.

    Results are displayed in a gr.HTML component. Raw text containing "<" or
    "&" would otherwise be parsed as markup, and newlines would collapse.
    """
    return f'<pre style="white-space: pre-wrap;">{html.escape(str(text))}</pre>'


@spaces.GPU  # requests a GPU on ZeroGPU hardware; documented as a no-op elsewhere
def process_single_image(image_path, got_mode, ocr_color, ocr_box):
    """Run GOT-OCR on one image and return an HTML fragment for display.

    Args:
        image_path: Filesystem path of the image.
        got_mode: Mode label. The substring "plain" or "format" selects
            ocr_type "ocr" or "format"; "multi-crop" selects
            ``model.chat_crop`` instead of ``model.chat``.
        ocr_color: Forwarded to ``model.chat`` (not used by multi-crop modes).
        ocr_box: Forwarded to ``model.chat`` (not used by multi-crop modes).

    Returns:
        - Plain modes: the recognised text, escaped inside <pre>.
        - Format modes: a download link plus an iframe preview of the rendered
          HTML, or the escaped raw result if no render file was produced.
        - Any other label: an "unknown OCR mode" error message.
    """
    multi_crop = "multi-crop" in got_mode
    if "plain" in got_mode:
        if multi_crop:
            res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
        else:
            res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
        return _as_preformatted_html(res)
    if "format" in got_mode:
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"
        # Remove a render left by an earlier run on the same image, so a
        # failed render is never mistaken for fresh output.
        try:
            os.remove(result_path)
        except FileNotFoundError:
            pass
        if multi_crop:
            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        else:
            res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        if not os.path.exists(result_path):
            # No rendered page was written: show the raw model output rather
            # than misreporting an unknown OCR mode.
            return _as_preformatted_html(res)
        with open(result_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        # Embed the render as a base64 data URI so it can be previewed and
        # downloaded without serving the file separately.
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
        preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
        download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
        return f"{download_link}\n\n{preview}"
    return "错误: 未知的OCR模式"
with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="上传PDF文件")
            convert_button = gr.Button("转换PDF为图片")
            # Visible status line so conversion errors actually reach the user.
            convert_status = gr.Textbox(label="状态", interactive=False)
        with gr.Column(scale=2):
            image_gallery = gr.Gallery(label="图片预览", columns=3)
    with gr.Row():
        with gr.Column(scale=1):
            # gr.State holds the path of the image selected in the gallery.
            selected_image = gr.State(value=None)
            preview_image = gr.Image(label="选中的图片", type="filepath")
            got_mode = gr.Dropdown(
                choices=[
                    "plain texts OCR",
                    "format texts OCR",
                    "plain multi-crop OCR",
                    "format multi-crop OCR",
                    "plain fine-grained OCR",
                    "format fine-grained OCR",
                ],
                label="OCR模式",
                value="plain texts OCR",
            )
            ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
            ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")
            ocr_button = gr.Button("开始OCR识别")
        with gr.Column(scale=2):
            ocr_output = gr.HTML(label="识别结果")

    def select_image(evt: gr.SelectData, gallery):
        """Return the clicked gallery item's path twice: once for the state, once for the preview."""
        selected = gallery[evt.index]
        # Depending on the Gradio version, gallery items may arrive as
        # (image, caption) pairs instead of bare paths; keep only the image.
        if isinstance(selected, (list, tuple)):
            selected = selected[0]
        return selected, selected

    image_gallery.select(select_image, image_gallery, [selected_image, preview_image])
    convert_button.click(convert_pdf_to_images, inputs=[pdf_input], outputs=[convert_status, image_gallery])
    ocr_button.click(ocr_process, inputs=[selected_image, got_mode, ocr_color, ocr_box], outputs=ocr_output)

if __name__ == "__main__":
    demo.launch()