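# GemmaPaperQA: a Streamlit app that extracts text from an uploaded PDF,
# indexes it in a FAISS vector store, and answers questions about it with a
# PEFT-finetuned Gemma model.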
import io
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig

# Global variables
knowledge_base = None
qa_chain = None

# Load a PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() may return None on some pages (older PyPDF2), so guard with ""
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text
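
# Note: extract_text() only recovers embedded text layers; scanned or
# image-only PDFs yield an empty string, which the empty-text check in
# main_page() reports to the user.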
# ํ ์คํธ๋ฅผ ์ฒญํฌ๋ก ๋ถํ | |
def split_text(text): | |
text_splitter = CharacterTextSplitter( | |
separator="\n", | |
chunk_size=1000, | |
chunk_overlap=200, | |
length_function=len | |
) | |
return text_splitter.split_text(text) | |
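
# Note: the 200-character overlap repeats the tail of each chunk at the head
# of the next, so sentences that straddle a chunk boundary stay retrievable.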

# Create the FAISS vector store
def create_knowledge_base(chunks):
    embeddings = HuggingFaceEmbeddings()
    return FAISS.from_texts(chunks, embeddings)
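
# Note: with no model_name argument, HuggingFaceEmbeddings falls back to
# LangChain's default sentence-transformers model (all-mpnet-base-v2 at the
# time of writing), downloaded on first use.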

# Load the Hugging Face model
def load_model():
    model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = PeftConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(model, model_name)
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=150,
        temperature=0.1,
    )
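
# Note: PeftConfig resolves the base checkpoint the adapter was trained on;
# the base model is loaded first and PeftModel layers the fine-tuned adapter
# weights on top. temperature=0.1 keeps answers close to deterministic.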

# Set up the QA chain
def setup_qa_chain():
    global qa_chain
    pipe = load_model()
    llm = HuggingFacePipeline(pipeline=pipe)
    qa_chain = load_qa_chain(llm, chain_type="stuff")
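
# Note: the "stuff" chain type concatenates every retrieved chunk into a
# single prompt; with a handful of ~1000-character chunks this fits within
# the model's context window.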

# Main page UI
def main_page():
    st.title("Welcome to GemmaPaperQA")
    st.subheader("Upload Your Paper")

    paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
    if paper:
        st.write(f"Upload complete! File name: {paper.name}")

        # Check the file size (paper.size reads metadata, so the read pointer is untouched)
        file_size = paper.size
        if file_size > 10 * 1024 * 1024:  # 10 MB limit
            st.error("File is too large! Please upload a file smaller than 10MB.")
            return

        # Intermediate check: preview the extracted PDF contents
        with st.spinner('Processing PDF...'):
            try:
                paper.seek(0)  # rewind the read pointer to the start of the file
                contents = paper.read()
                pdf_file = io.BytesIO(contents)
                text = load_pdf(pdf_file)

                # Handle PDFs from which no text could be extracted
                if len(text.strip()) == 0:
                    st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                    return

                st.text_area("Preview of extracted text", text[:1000], height=200)
                st.write(f"Total characters extracted: {len(text)}")

                if st.button("Proceed with this file"):
                    chunks = split_text(text)
                    global knowledge_base
                    knowledge_base = create_knowledge_base(chunks)
                    if knowledge_base is None:
                        st.error("Failed to create knowledge base.")
                        return

                    st.session_state.paper_name = paper.name[:-4]
                    st.session_state.page = "chat"
                    setup_qa_chain()
                    st.success("PDF successfully processed! You can now ask questions.")
            except Exception as e:
                st.error(f"Failed to process the PDF: {str(e)}")

# Chat page UI
def chat_page():
    st.title(f"Ask anything about {st.session_state.paper_name}")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Chat here!"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        response = get_response_from_model(prompt)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    if st.button("Go back to main page"):
        st.session_state.page = "main"
        st.rerun()  # rerun so the page switch takes effect immediately

# Generate a model response
def get_response_from_model(prompt):
    try:
        global knowledge_base, qa_chain
        if not knowledge_base:
            return "No PDF has been uploaded yet."
        if not qa_chain:
            return "QA chain is not initialized."

        # Retrieve the most relevant chunks and answer from them
        docs = knowledge_base.similarity_search(prompt)
        response = qa_chain.run(input_documents=docs, question=prompt)

        # Strip the prompt scaffolding if the model echoes it back
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:")[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"
# ํ์ด์ง ์ค์ | |
if "page" not in st.session_state: | |
st.session_state.page = "main" | |
if "paper_name" not in st.session_state: | |
st.session_state.paper_name = "" | |
# ํ์ด์ง ๋ ๋๋ง | |
if st.session_state.page == "main": | |
main_page() | |
elif st.session_state.page == "chat": | |
chat_page() | |