Mageia commited on
Commit
79746f6
1 Parent(s): 4d31938

fix: process pdf once

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -2,15 +2,12 @@ import base64
2
  import os
3
  import uuid
4
 
 
5
  import torch
6
- from fastapi import FastAPI, File, UploadFile
7
- from fastapi.responses import JSONResponse
8
  from transformers import AutoConfig, AutoModel, AutoTokenizer
9
 
10
  from got_ocr import got_ocr
11
 
12
- app = FastAPI()
13
-
14
  # 初始化模型和分词器
15
  model_name = "ucaslcl/GOT-OCR2_0"
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -27,12 +24,13 @@ UPLOAD_FOLDER = "./uploads"
27
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
28
 
29
 
30
- @app.post("/ocr")
31
- async def perform_ocr(image: UploadFile = File(...)):
 
 
32
  # 保存上传的图片
33
  image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
34
- with open(image_path, "wb") as buffer:
35
- buffer.write(await image.read())
36
 
37
  # 执行OCR
38
  result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
@@ -40,15 +38,27 @@ async def perform_ocr(image: UploadFile = File(...)):
40
  # 删除临时文件
41
  os.remove(image_path)
42
 
43
- # 准备响应
44
- response = {"result": result}
45
  if html_content:
46
- response["html_content"] = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
47
-
48
- return JSONResponse(content=response)
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  if __name__ == "__main__":
52
- import uvicorn
53
-
54
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  import os
3
  import uuid
4
 
5
+ import gradio as gr
6
  import torch
 
 
7
  from transformers import AutoConfig, AutoModel, AutoTokenizer
8
 
9
  from got_ocr import got_ocr
10
 
 
 
11
  # 初始化模型和分词器
12
  model_name = "ucaslcl/GOT-OCR2_0"
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
24
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
25
 
26
 
27
+ def perform_ocr(image):
28
+ if image is None:
29
+ return "请上传图片"
30
+
31
  # 保存上传的图片
32
  image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
33
+ image.save(image_path)
 
34
 
35
  # 执行OCR
36
  result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
 
38
  # 删除临时文件
39
  os.remove(image_path)
40
 
 
 
41
  if html_content:
42
+ encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
43
+ iframe_src = f"data:text/html;base64,{encoded_html}"
44
+ iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
45
+ download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result.html">下载完整结果</a>'
46
+ return gr.HTML(f"{download_link}<br>{iframe}")
47
+ else:
48
+ return gr.Markdown(result)
49
+
50
+
51
+ # 创建 Gradio 界面
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("# OCR 图像识别")
54
+ with gr.Row():
55
+ image_input = gr.Image(type="pil", label="上传图片")
56
+ with gr.Row():
57
+ ocr_button = gr.Button("开始OCR识别")
58
+ with gr.Row():
59
+ output = gr.HTML(label="OCR结果")
60
+
61
+ ocr_button.click(fn=perform_ocr, inputs=image_input, outputs=output)
62
 
63
  if __name__ == "__main__":
64
+ demo.launch()