# gemma_paper_qa / app.py
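"""Streamlit app for GemmaPaperQA: upload a paper as a PDF, index it with FAISS,
and ask questions about it through a HuggingFace text-generation QA chain."""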
import os
import io
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Global variables
knowledge_base = None
qa_chain = None
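
# MODEL_PATH must name a HuggingFace causal-LM checkpoint (a Gemma model id, given
# the app name); setup_qa_chain() loads it. Reading it from the environment here is
# an assumption, so export MODEL_PATH before launching the app.
MODEL_PATH = os.environ.get("MODEL_PATH", "")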
def load_pdf(pdf_file):
"""
Load and extract text from a PDF.
"""
pdf_reader = PdfReader(pdf_file)
text = "".join(page.extract_text() for page in pdf_reader.pages)
return text
def split_text(text):
"""
Split the extracted text into chunks.
"""
text_splitter = CharacterTextSplitter(
separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
)
return text_splitter.split_text(text)
def create_knowledge_base(chunks):
"""
Create a FAISS knowledge base from text chunks.
"""
embeddings = HuggingFaceEmbeddings()
return FAISS.from_texts(chunks, embeddings)
def load_model(model_path):
"""
Load the HuggingFace model and tokenizer, and create a text-generation pipeline.
"""
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
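    # max_new_tokens caps the answer length; a low temperature with sampling enabled
    # keeps generations close to the retrieved context.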
    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, do_sample=True, temperature=0.1)
def setup_qa_chain():
"""
Set up the question-answering chain.
"""
global qa_chain
pipe = load_model(MODEL_PATH)
llm = HuggingFacePipeline(pipeline=pipe)
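    # The "stuff" chain concatenates every retrieved chunk into one prompt for the model.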
qa_chain = load_qa_chain(llm, chain_type="stuff")
# Streamlit UI
def main_page():
st.title("Welcome to GemmaPaperQA")
st.subheader("Upload Your Paper")
paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
if paper:
st.write(f"Upload complete! File name is {paper.name}")
st.write("Please click the button below.")
if st.button("Click Here :)"):
try:
                # Process the uploaded PDF file
contents = paper.read()
pdf_file = io.BytesIO(contents)
text = load_pdf(pdf_file)
chunks = split_text(text)
global knowledge_base
knowledge_base = create_knowledge_base(chunks)
st.success("PDF successfully processed! You can now ask questions.")
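                # Remember the paper name (minus the ".pdf" extension) and switch to the chat page.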
st.session_state.paper_name = paper.name[:-4]
st.session_state.page = "chat"
setup_qa_chain()
except Exception as e:
st.error(f"Failed to process the PDF: {str(e)}")
def chat_page():
st.title(f"Ask anything about {st.session_state.paper_name}")
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Chat here!"):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
response = get_response_from_model(prompt)
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
if st.button("Go back to main page"):
st.session_state.page = "main"
def get_response_from_model(prompt):
try:
global knowledge_base, qa_chain
if not knowledge_base:
return "No PDF has been uploaded yet."
if not qa_chain:
return "QA chain is not initialized."
docs = knowledge_base.similarity_search(prompt)
response = qa_chain.run(input_documents=docs, question=prompt)
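        # The default QA prompt ends with "Helpful Answer:"; if the model output
        # echoes the prompt, keep only the text after that marker.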
if "Helpful Answer:" in response:
response = response.split("Helpful Answer:")[1].strip()
return response
except Exception as e:
return f"Error: {str(e)}"
# Streamlit - set the initial page
if "page" not in st.session_state:
st.session_state.page = "main"
# Initialize paper_name
if "paper_name" not in st.session_state:
st.session_state.paper_name = ""
# ํŽ˜์ด์ง€ ๋ Œ๋”๋ง
if st.session_state.page == "main":
main_page()
elif st.session_state.page == "chat":
chat_page()