gemma_paper_qa / app.py
halyn's picture
code update
3e1aa0b
raw
history blame
5.37 kB
import io
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Module-level state shared by the page functions below.
# NOTE: Streamlit executes this script top to bottom on every interaction, so
# this line rebinds both names to None at the start of each run.
#   knowledge_base: FAISS store assigned by main_page(); None until a PDF is processed.
#   qa_chain: QA chain assigned by setup_qa_chain(); None until initialised.
knowledge_base, qa_chain = None, None
# Load a PDF file and extract its text
def load_pdf(pdf_file):
    """Extract the text of every page of ``pdf_file`` as a single string.

    Args:
        pdf_file: Anything ``PdfReader`` accepts (the caller passes an
            ``io.BytesIO`` holding the uploaded PDF bytes).

    Returns:
        The text of all pages in order, separated by a newline. A page that
        yields no text contributes an empty string.
    """
    pdf_reader = PdfReader(pdf_file)
    # Join pages with "\n" so the last word of one page is not glued to the
    # first word of the next; this also gives the "\n"-separated text splitter
    # a boundary between pages. `or ""` defensively treats a page that yields
    # no text (e.g. None) as empty instead of crashing the join.
    return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
# ํ…์ŠคํŠธ๋ฅผ ์ฒญํฌ๋กœ ๋ถ„ํ• 
def split_text(text):
text_splitter = CharacterTextSplitter(
separator="\n",
chunk_size=1000,
chunk_overlap=200,
length_function=len
)
return text_splitter.split_text(text)
# Create the FAISS vector store
def create_knowledge_base(chunks):
    """Embed ``chunks`` and index them in a FAISS vector store.

    Args:
        chunks: The text chunks to index (e.g. the output of ``split_text``).

    Returns:
        A ``FAISS`` vector store built from ``chunks``, or ``None`` when
        ``chunks`` is empty. Building an index from zero texts fails inside
        ``FAISS.from_texts``; returning ``None`` lets callers report it.
    """
    if not chunks:
        return None
    # No model_name is passed, so HuggingFaceEmbeddings uses its library default.
    embeddings = HuggingFaceEmbeddings()
    return FAISS.from_texts(chunks, embeddings)
# Load the Hugging Face model
@st.cache_resource
def load_model():
    """Return a text-generation pipeline for ``halyn/gemma2-2b-it-finetuned-paperqa``.

    Decorated with ``st.cache_resource`` so the tokenizer and model weights are
    loaded once per server process and then reused. Without it they would be
    reloaded on every call to ``setup_qa_chain()`` (once per processed paper).
    A plain module-level cache would not help, because Streamlit re-executes
    this script on every interaction.

    Returns:
        A ``transformers`` text-generation pipeline that generates at most 150
        new tokens.
    """
    model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # NOTE(review): in transformers, `temperature` only takes effect when
    # sampling is enabled (`do_sample=True`). Unless the model's generation
    # config enables sampling, decoding is greedy and this value is ignored.
    # Confirm the intended decoding mode.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=150,
        temperature=0.1,
    )
# Set up the QA chain
def setup_qa_chain():
    """Build the "stuff" question-answering chain and store it in the global ``qa_chain``.

    The transformers pipeline from ``load_model()`` is wrapped in a LangChain
    ``HuggingFacePipeline`` LLM, which is then passed to ``load_qa_chain``.
    """
    global qa_chain
    qa_chain = load_qa_chain(
        HuggingFacePipeline(pipeline=load_model()),
        chain_type="stuff",
    )
# Main page UI
def main_page():
    """Render the upload page: accept a PDF, preview its text, then build QA state.

    When the user clicks "Proceed with this file", the extracted text is
    chunked and indexed into the global ``knowledge_base``. The QA chain is
    then initialised via ``setup_qa_chain()``. Only after both succeed is the
    session switched to the chat page. Any exception while processing is
    shown with ``st.error`` and the session stays on this page.
    """
    global knowledge_base

    st.title("Welcome to GemmaPaperQA")
    st.subheader("Upload Your Paper")
    paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
    if not paper:
        return

    st.write(f"Upload complete! File name: {paper.name}")

    # Check the file size from the reported attribute, without moving the file pointer.
    if paper.size > 10 * 1024 * 1024:  # 10 MB limit
        st.error("File is too large! Please upload a file smaller than 10MB.")
        return

    # Intermediate confirmation step: preview the extracted text before indexing.
    with st.spinner('Processing PDF...'):
        try:
            paper.seek(0)  # reset the read pointer to the beginning of the file
            text = load_pdf(io.BytesIO(paper.read()))

            # Stop early if no text could be extracted.
            if not text.strip():
                st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                return

            st.text_area("Preview of extracted text", text[:1000], height=200)
            st.write(f"Total characters extracted: {len(text)}")

            if st.button("Proceed with this file"):
                chunks = split_text(text)
                knowledge_base = create_knowledge_base(chunks)
                if knowledge_base is None:
                    st.error("Failed to create knowledge base.")
                    return
                setup_qa_chain()
                # Switch to the chat page only once the QA chain is ready. If
                # setup_qa_chain() raises, we fall through to the except
                # branch and the user stays here with an error message.
                st.session_state.paper_name = paper.name[:-4]  # drop ".pdf"
                st.session_state.page = "chat"
                st.success("PDF successfully processed! You can now ask questions.")
        except Exception as e:
            st.error(f"Failed to process the PDF: {str(e)}")
# Chat page UI
def chat_page():
    """Render the chat page for the processed paper.

    Replays the conversation kept in ``st.session_state.messages``. A newly
    submitted question is answered with ``get_response_from_model``, and both
    turns are appended to that history. A button sets the page back to "main".
    """
    state = st.session_state
    st.title(f"Ask anything about {state.paper_name}")

    # Start with an empty history the first time the chat page is shown.
    if "messages" not in state:
        state.messages = []

    # Redraw every earlier turn, because the page is rebuilt on each run.
    for entry in state.messages:
        role = entry["role"]
        with st.chat_message(role):
            st.markdown(entry["content"])

    question = st.chat_input("Chat here!")
    if question:
        state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        answer = get_response_from_model(question)
        with st.chat_message("assistant"):
            st.markdown(answer)
        state.messages.append({"role": "assistant", "content": answer})

    if st.button("Go back to main page"):
        state.page = "main"
# Model response handling
def get_response_from_model(prompt):
    """Answer ``prompt`` from the processed paper.

    Retrieves the chunks most similar to ``prompt`` from the global
    ``knowledge_base`` and runs the global ``qa_chain`` over them. If the
    output contains "Helpful Answer:", returns the stripped text between the
    first occurrence and the next one (if any). Otherwise returns the raw
    chain output.

    Missing state yields a status message instead of an answer. Any
    ``Exception`` is caught and returned as ``"Error: <message>"``.
    """
    try:
        # Both names are only read here, so no `global` statement is needed.
        if not knowledge_base:
            return "No PDF has been uploaded yet."
        if not qa_chain:
            return "QA chain is not initialized."

        relevant_docs = knowledge_base.similarity_search(prompt)
        answer = qa_chain.run(input_documents=relevant_docs, question=prompt)

        marker = "Helpful Answer:"
        if marker in answer:
            answer = answer.split(marker)[1].strip()
        return answer
    except Exception as exc:
        return f"Error: {str(exc)}"
# Page setup: initialise per-session navigation state on the first run.
if "page" not in st.session_state:
    st.session_state.page = "main"
if "paper_name" not in st.session_state:
    st.session_state.paper_name = ""

# Streamlit re-executes this whole script on every interaction, so the
# module-level `knowledge_base` / `qa_chain` assignments above rebind both
# names to None on each run. Without the lines below, the chat page could
# never see the index and chain built by the main page on an earlier run, and
# every question would be answered "No PDF has been uploaded yet." Restore
# them from session_state, which persists across reruns of the same session.
knowledge_base = st.session_state.get("knowledge_base", knowledge_base)
qa_chain = st.session_state.get("qa_chain", qa_chain)

# Page rendering
if st.session_state.page == "main":
    main_page()
    # main_page()/setup_qa_chain() assign these globals via `global`. Persist
    # whatever they now hold so the next rerun can restore it.
    st.session_state.knowledge_base = knowledge_base
    st.session_state.qa_chain = qa_chain
elif st.session_state.page == "chat":
    chat_page()