breeztest / app.py
simonzhang5429's picture
Update app.py
f669dd9 verified
raw
history blame contribute delete
No virus
3.38 kB
import gradio as gr
import os
import shutil
from pypdf import PdfReader
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import fitz
# Hugging Face repository that supplies both the tokenizer and the weights.
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_REPO,
    use_fast=True,
    local_files_only=False,
)

# Instruction prepended to every chunk sent to the model
# ("Please convert the following text to Traditional Chinese:").
tran_hints = "请将以下的文字转为繁体:"

# Sequence delimiter strings; no live code visible here reads them.
start_flag = "<s>"
end_flag = "</s>"

# Loaded once at import time; device_map="auto" leaves weight placement to
# the library and the weights are held in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    TOKENIZER_REPO,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=False,
)
def generate(text):
    """Run one single-turn chat completion and return the model's reply.

    Args:
        text: The user message. Surrounding whitespace is stripped.

    Returns:
        The decoded reply only; the echoed prompt and special tokens are
        left out. An empty string when ``text`` is blank.
    """
    text = text.strip()
    if not text:
        # A blank message would hand the chat template an empty
        # conversation; there is nothing meaningful to generate from.
        return ""

    chat_data = [{"role": "user", "content": text}]
    input_ids = tokenizer.apply_chat_template(chat_data, return_tensors="pt")
    # The model is loaded with device_map="auto", so the prompt has to be
    # moved to wherever the model actually lives.
    input_ids = input_ids.to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=2048,
        # Greedy decoding, stated explicitly: temperature/top_p/top_k only
        # apply when sampling is enabled.
        do_sample=False,
        repetition_penalty=1.1,
    )
    # generate() returns prompt + continuation; keep the continuation only.
    reply_ids = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
def tran_txt(input_txt):
    """Ask the model to convert one piece of text.

    The prompt is the module-level ``tran_hints`` instruction, a newline,
    and the stripped input. Whatever ``generate`` returns is logged to
    stdout and handed back unmodified.

    Args:
        input_txt: Text to convert.

    Returns:
        The string returned by ``generate`` for the assembled prompt.
    """
    prompt = "\n".join((tran_hints, input_txt.strip()))
    converted = generate(prompt)
    print(f"tran_result={converted}")
    return converted
def exec_tran(file):
    """Convert an uploaded PDF and write the result to a text file.

    The PDF is copied with ``upload_file``, split into pieces with
    ``read_paragraphs``, and every piece is passed through ``tran_txt``.
    Each converted piece is written on its own line.

    Args:
        file: The uploaded file, either as a path (``str`` or
            ``os.PathLike``) or as an object whose ``name`` attribute is
            the path.

    Returns:
        Path of the result file: the upload path with a trailing ``.pdf``
        extension (any letter case) replaced by ``_result.txt``, or with
        ``_result.txt`` appended when there is no such extension.
    """
    if isinstance(file, (str, os.PathLike)):
        source_path = os.fspath(file)
    else:
        source_path = file.name

    staged_pdf = upload_file(source_path)
    pieces = read_paragraphs(staged_pdf)

    # splitext only looks at the final extension, so a ".pdf" elsewhere in
    # the path is left alone and a missing extension is not an error.
    stem, extension = os.path.splitext(source_path)
    if extension.lower() != ".pdf":
        stem = source_path
    result_path = stem + "_result.txt"

    # Explicit UTF-8: the output is Chinese text and must not depend on the
    # platform's default encoding.
    with open(result_path, "w", encoding="utf-8") as fw:
        for piece in pieces:
            fw.write(tran_txt(piece) + "\n")
    return result_path
def upload_file(file):
    """Copy an uploaded file into the local ``./data`` folder.

    Args:
        file: Path of the file to copy (``str`` or ``os.PathLike``), or an
            object whose ``name`` attribute holds that path.

    Returns:
        Path of the copy inside ``./data``, as returned by ``shutil.copy``.
    """
    upload_dir = "./data"
    # exist_ok=True removes the check-then-create race between two uploads
    # arriving at the same time.
    os.makedirs(upload_dir, exist_ok=True)
    source = file if isinstance(file, (str, os.PathLike)) else file.name
    return shutil.copy(source, upload_dir)
def read_paragraphs(pdf_path):
    """Extract the text of a PDF and split it into sentence-sized pieces.

    Each page's plain text is split on the Chinese full stop ``。``; pieces
    that are empty or whitespace-only are dropped. The full stop itself is
    not kept, and pieces never span two pages.

    Args:
        pdf_path: Path of the PDF to read.

    Returns:
        The non-blank pieces, in page order.
    """
    paragraphs = []
    # The context manager closes the document even if extraction raises.
    with fitz.open(pdf_path) as document:
        for page in document:
            # "text" is the documented plain-text extraction mode.
            text = page.get_text("text")
            paragraphs.extend(
                piece for piece in text.split("。") if piece.strip()
            )
    return paragraphs
def load_pdf_pages(filename):
    """Return the text extracted from every page of a PDF, in page order.

    Args:
        filename: PDF to open with ``PdfReader``.

    Returns:
        A list holding the ``extract_text()`` result of each page.
    """
    return [page.extract_text() for page in PdfReader(filename).pages]
def exec_translate(file):
    """Copy the upload to the data folder and extract its page texts.

    The extracted texts are not used or returned; the function always
    returns ``None``.

    Args:
        file: Uploaded file object; its ``name`` attribute is read as the
            path of the PDF.
    """
    upload_file(file)
    load_pdf_pages(file.name)
# UI: one upload button that runs exec_tran on the chosen file and one file
# component that offers the resulting text file.
with gr.Blocks() as app:
    file_output = gr.File()
    upload_button = gr.UploadButton(
        "上传pdf文件",
        # Extension filters need the leading dot; a bare "pdf" is treated as
        # a file-type category rather than an extension.
        file_types=[".pdf"],
        file_count="single",
    )
    upload_button.upload(fn=exec_tran, inputs=upload_button, outputs=file_output)

app.launch()