Mageia committed
Commit 61a37e8
Parent: ffc8ac2

fix: process pdf once

Files changed (1)
  1. app.py +17 -77
app.py CHANGED
@@ -1,99 +1,39 @@
-import base64
-import os
-import uuid
-
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer

-# Initialize the model and tokenizer
-model_name = "stepfun-ai/GOT-OCR2_0"
+model_name = "ucaslcl/GOT-OCR2_0"
 device = "cuda" if torch.cuda.is_available() else "cpu"

 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map="cuda", use_safetensors=True)
+model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
 model = model.eval().to(device)
-model.config.pad_token_id = tokenizer.eos_token_id
-
-UPLOAD_FOLDER = "./uploads"
-
-# Make sure the upload folder exists
-os.makedirs(UPLOAD_FOLDER, exist_ok=True)


 @spaces.GPU()
-def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
-    # Run OCR
-    try:
-        if got_mode == "plain texts OCR":
-            res = model.chat(tokenizer, image_path, ocr_type="ocr")
-            return res, None
-        elif got_mode == "format texts OCR":
-            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
-            res = model.chat(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
-        elif got_mode == "plain multi-crop OCR":
-            res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
-            return res, None
-        elif got_mode == "format multi-crop OCR":
-            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
-            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
-        elif got_mode == "plain fine-grained OCR":
-            res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
-            return res, None
-        elif got_mode == "format fine-grained OCR":
-            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
-            res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
-
-        # Handle the formatted result
-        if "format" in got_mode and os.path.exists(result_path):
-            with open(result_path, "r") as f:
-                html_content = f.read()
-            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
-            return res, encoded_html
-        else:
-            return res, None
-
-    except Exception as e:
-        return f"错误: {str(e)}", None
-
-
-def perform_ocr(image):
+def ocr_process(image):
     if image is None:
-        return "请上传图片"
-
-    # Save the uploaded image
-    image_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.png")
-    image.save(image_path)
+        return "错误:未提供图片"

-    # Run OCR
-    result, html_content = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
-
-    # Delete the temporary file
-    os.remove(image_path)
-
-    if html_content:
-        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
-        iframe_src = f"data:text/html;base64,{encoded_html}"
-        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
-        download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result.html">下载完整结果</a>'
-        return gr.HTML(f"{download_link}<br>{iframe}")
-    else:
-        return gr.Markdown(result)
+    try:
+        res = model.chat(tokenizer, image, ocr_type="ocr")
+        return res
+    except Exception as e:
+        return f"错误: {str(e)}"


-# Build the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# OCR 图像识别")
+
     with gr.Row():
-        image_input = gr.Image(type="pil", label="上传图片")
-    with gr.Row():
-        ocr_button = gr.Button("开始OCR识别")
-    with gr.Row():
-        output = gr.HTML(label="OCR结果")
+        image_input = gr.Image(type="filepath", label="上传图片")
+
+    submit_button = gr.Button("开始OCR识别")
+
+    output_text = gr.Textbox(label="识别结果")

-    ocr_button.click(fn=perform_ocr, inputs=image_input, outputs=output)
+    submit_button.click(ocr_process, inputs=[image_input], outputs=[output_text])

 if __name__ == "__main__":
     demo.launch()
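
After this change, app.py hands every request to a single plain-text OCR call, model.chat(tokenizer, image, ocr_type="ocr"), where image is the file path Gradio supplies because the input component now uses type="filepath". Below is a minimal sketch of the same call outside the Gradio UI; the loading code mirrors the diff, while the sample.png path and the CPU fallback are illustrative only (the Space itself runs the call on a GPU via @spaces.GPU()).

import torch
from transformers import AutoModel, AutoTokenizer

# Same checkpoint and loading arguments as the updated app.py.
model_name = "ucaslcl/GOT-OCR2_0"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
model = model.eval().to(device)

# "sample.png" is a placeholder path; ocr_type="ocr" is the same plain-text mode
# that ocr_process() uses, and chat() comes from the model's trust_remote_code class.
result = model.chat(tokenizer, "sample.png", ocr_type="ocr")
print(result)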