# lr_pdf / Document_QA.py
# Uploaded by coldlarry; commit a34348d ("add paper").
import openai
import faiss
import numpy as np
import pickle
from tqdm import tqdm
import argparse
import os
from PyPDF2 import PdfReader
class Paper(object):
    """Expose the text of a parsed PDF as fixed-size character chunks."""

    def __init__(self, pdf_obj: "PdfReader") -> None:
        """Keep the reader and its metadata.

        Args:
            pdf_obj: Reader object providing ``metadata`` and ``pages``; each
                page must provide ``extract_text()``.
        """
        self._pdf_obj = pdf_obj
        self._paper_meta = self._pdf_obj.metadata
        # Chunks produced by the most recent get_texts() call.
        self.texts = []

    def iter_pages(self, iter_text_len: int = 1000):
        """Yield ``(page_idx, part_idx, text)`` for every chunk of every page.

        Each page's text is cut into consecutive slices of at most
        ``iter_text_len`` characters. Pages without text yield nothing, and
        no empty trailing slice is produced when the text length is an exact
        multiple of ``iter_text_len``.

        Args:
            iter_text_len: Maximum number of characters per chunk.

        Raises:
            ValueError: If ``iter_text_len`` is not positive.
        """
        if iter_text_len <= 0:
            raise ValueError("iter_text_len must be a positive integer")
        for page_idx, page in enumerate(self._pdf_obj.pages):
            # Defensive: treat a falsy extraction result as "no text".
            txt = page.extract_text() or ""
            starts = range(0, len(txt), iter_text_len)
            for part_idx, start in enumerate(starts):
                yield page_idx, part_idx, txt[start:start + iter_text_len]

    def get_texts(self):
        """Return the stripped, non-empty chunks of the whole document.

        The list is rebuilt on every call (and stored in ``self.texts``), so
        calling this method repeatedly does not duplicate chunks.
        """
        stripped = (text.strip() for _, _, text in self.iter_pages())
        # Whitespace-only chunks carry nothing worth embedding.
        self.texts = [chunk for chunk in stripped if chunk]
        return self.texts
def create_embeddings(inputs):
    """Create embeddings for every text in *inputs*.

    One request is sent to the ``text-embedding-ada-002`` model per text.

    Args:
        inputs: Iterable of strings to embed.

    Returns:
        A tuple ``(result, tokens)`` where ``result`` is a list of
        ``(text, embedding)`` pairs in input order (blank texts are left out)
        and ``tokens`` is the sum of ``usage.total_tokens`` over all requests.
    """
    result = []
    tokens = 0
    for text in inputs:
        # Blank texts have nothing to embed, and the embeddings endpoint
        # rejects empty input, so skip them instead of failing the whole run.
        if not text or not text.strip():
            continue
        batch = [text]
        response = openai.Embedding.create(model="text-embedding-ada-002", input=batch)
        tokens += response.usage.total_tokens
        result.extend((chunk, item.embedding) for chunk, item in zip(batch, response.data))
    return result, tokens
def create_embedding(text):
    """Embed a single text.

    Returns:
        A tuple ``(text, embedding)`` where ``embedding`` is taken from the
        first item of the response's ``data``.
    """
    response = openai.Embedding.create(input=text, model="text-embedding-ada-002")
    first_item = response.data[0]
    return text, first_item.embedding
class QA():
    """Answer questions about a set of embedded text chunks.

    The chunks are indexed with a flat L2 Faiss index; a query is embedded,
    its nearest chunks are retrieved and handed to a chat model as context.
    """

    def __init__(self, data_embe) -> None:
        """Build the index.

        Args:
            data_embe: Sequence of ``(text, embedding)`` pairs.

        Raises:
            ValueError: If ``data_embe`` is empty.
        """
        pairs = list(data_embe)
        if not pairs:
            raise ValueError("data_embe must contain at least one (text, embedding) pair")
        # Faiss operates on float32 matrices; be explicit instead of relying
        # on an implicit conversion from NumPy's float64 default.
        embe = np.array([emm[1] for emm in pairs], dtype="float32")
        data = [emm[0] for emm in pairs]
        # Take the dimension from the data rather than hard-coding it.
        index = faiss.IndexFlatL2(embe.shape[1])
        index.add(embe)
        # All embeddings.
        self.index = index
        # All texts, in the same order as the indexed vectors.
        self.data = data
        print("now all data is:\n",self.data)

    def __call__(self, query, limit=5):
        """Answer *query*.

        Args:
            query: The user's question.
            limit: Number of nearest chunks to retrieve.

        Returns:
            A tuple ``(answer, context)``: the model's reply and the list of
            text chunks that were supplied to it.
        """
        embedding = create_embedding(query)
        # Texts related to the user's question.
        context = self.get_texts(embedding[1], limit)
        # Give the question and the related texts to the chat model.
        answer = self.completion(query, context)
        return answer, context

    def get_texts(self, embeding, limit=5):
        """Return the chunks nearest to *embeding*, each with its successor.

        Args:
            embeding: Query embedding (one vector).
            limit: Number of nearest neighbours requested from the index.

        Returns:
            Text chunks ordered by search rank. Every hit contributes itself
            and the chunk that follows it; a chunk is listed only once even
            when several hits would contribute it.
        """
        query_vec = np.array([embeding], dtype="float32")
        _, text_index = self.index.search(query_vec, limit)
        context = []
        seen = set()
        for raw_idx in text_index[0]:
            hit = int(raw_idx)
            # Negative ids pad the result when the index holds fewer than
            # `limit` vectors.
            if hit < 0:
                continue
            for pos in (hit, hit + 1):
                if pos < len(self.data) and pos not in seen:
                    seen.add(pos)
                    context.append(self.data[pos])
        return context

    def completion(self, query, context):
        """Ask the chat model to answer *query* from *context*.

        Args:
            query: The user's question.
            context: Text chunks, enumerated from 0 in the system prompt.

        Returns:
            The content of the first choice of the response.
        """
        numbered = "\n".join(f"{pos}. {chunk}" for pos, chunk in enumerate(context))
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {'role': 'system',
                 'content': f'你是一个有帮助的AI文章助手,从下文中提取有用的内容进行回答,不能回答不在下文提到的内容,相关性从高到底排序:\n\n{numbered}'},
                {'role': 'user', 'content': query},
            ],
        )
        print("使用的tokens:", response.usage.total_tokens)
        return response.choices[0].message.content
if __name__ == '__main__':
    # Interactive command-line loop: embed a PDF once, then answer questions.
    parser = argparse.ArgumentParser(description="Document QA")
    parser.add_argument("--input_file", default="slimming-pages-1.pdf", dest="input_file", type=str,help="输入文件路径")
    parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
    args = parser.parse_args()

    # Paper reads `.metadata` and `.pages` from its argument, so it needs a
    # reader object rather than the bare path string.
    paper = Paper(PdfReader(args.input_file))
    all_texts = paper.get_texts()
    # NOTE: embeddings are recomputed on every run; nothing is cached on disk.
    data_embe, tokens = create_embeddings(all_texts)
    print("全部文本消耗 {} tokens".format(tokens))
    qa = QA(data_embe)

    # Number of nearest chunks to retrieve; 5 matches get_texts's default,
    # which is what was effectively in use before `limit` was wired in.
    limit = 5
    while True:
        query = input("请输入查询(help可查看指令):")
        if query == "quit":
            break
        elif query.startswith("limit"):
            try:
                new_limit = int(query.split(" ")[1])
                if new_limit <= 0:
                    raise ValueError("limit必须是正整数")
                limit = new_limit
                print("已设置limit为", limit)
            except (IndexError, ValueError) as e:
                print("设置limit失败", e)
            continue
        elif query == "help":
            print("输入limit [数字]设置limit")
            print("输入quit退出")
            continue

        # Run the retrieval and completion steps directly so that the
        # user-selected limit actually reaches get_texts.
        _, query_embedding = create_embedding(query)
        context = qa.get_texts(query_embedding, limit)
        answer = qa.completion(query, context)

        if args.print_context:
            print("已找到相关片段:")
            for text in context:
                print('\t', text)
            print("=====================================")
        print("回答如下\n\n")
        print(answer.strip())
        print("=====================================")