Spaces:
Running
Running
fix: process pdf once
Browse files
app.py
CHANGED
@@ -2,15 +2,12 @@ import base64
|
|
2 |
import os
|
3 |
import uuid
|
4 |
|
|
|
5 |
import torch
|
6 |
-
from fastapi import FastAPI, File, UploadFile
|
7 |
-
from fastapi.responses import JSONResponse
|
8 |
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
9 |
|
10 |
from got_ocr import got_ocr
|
11 |
|
12 |
-
app = FastAPI()
|
13 |
-
|
14 |
# 初始化模型和分词器
|
15 |
model_name = "ucaslcl/GOT-OCR2_0"
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -27,12 +24,13 @@ UPLOAD_FOLDER = "./uploads"
|
|
27 |
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
28 |
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
32 |
# 保存上传的图片
|
33 |
image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
|
34 |
-
|
35 |
-
buffer.write(await image.read())
|
36 |
|
37 |
# 执行OCR
|
38 |
result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
|
@@ -40,15 +38,27 @@ async def perform_ocr(image: UploadFile = File(...)):
|
|
40 |
# 删除临时文件
|
41 |
os.remove(image_path)
|
42 |
|
43 |
-
# 准备响应
|
44 |
-
response = {"result": result}
|
45 |
if html_content:
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
if __name__ == "__main__":
|
52 |
-
|
53 |
-
|
54 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
2 |
import os
|
3 |
import uuid
|
4 |
|
5 |
+
import gradio as gr
|
6 |
import torch
|
|
|
|
|
7 |
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
8 |
|
9 |
from got_ocr import got_ocr
|
10 |
|
|
|
|
|
11 |
# 初始化模型和分词器
|
12 |
model_name = "ucaslcl/GOT-OCR2_0"
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
24 |
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
25 |
|
26 |
|
27 |
+
def perform_ocr(image):
    """Run OCR on an uploaded image and return content for the output panel.

    The image is written to a uniquely named temporary PNG under
    ``UPLOAD_FOLDER`` because ``got_ocr`` is called with a file path. The
    temporary file is removed afterwards, including when saving or OCR
    raises, so failed requests do not leave files behind.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
            It must provide a ``save(path)`` method.

    Returns:
        The prompt string ``"请上传图片"`` when ``image`` is ``None``; a
        ``gr.HTML`` holding a download link plus an inline preview when
        ``got_ocr`` returned HTML content; otherwise a ``gr.Markdown``
        wrapping the plain OCR result.
    """
    if image is None:
        return "请上传图片"

    # Save the upload under a collision-free name.
    image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
    try:
        image.save(image_path)

        # Run OCR on the saved file.
        result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
    finally:
        # Always delete the temporary file; it may not exist if save() failed
        # before creating it.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass

    if html_content:
        # Embed the rendered HTML as a base64 data URI so it can be both
        # previewed in an iframe and downloaded without a separate endpoint.
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        data_uri = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
        download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
        return gr.HTML(f"{download_link}<br>{iframe}")

    return gr.Markdown(result)
|
49 |
+
|
50 |
+
|
51 |
+
# Build the Gradio interface: an image upload, a trigger button, and an
# HTML panel that displays whatever perform_ocr returns.
with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")

    with gr.Row():
        image_input = gr.Image(type="pil", label="上传图片")

    with gr.Row():
        ocr_button = gr.Button("开始OCR识别")

    with gr.Row():
        output = gr.HTML(label="OCR结果")

    # Clicking the button sends the uploaded image through perform_ocr and
    # shows the result in the output panel.
    ocr_button.click(
        fn=perform_ocr,
        inputs=[image_input],
        outputs=[output],
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()
|
|
|
|