"""GemmaPaperQA — a Streamlit app for chatting with an uploaded PDF paper.

Pipeline: PDF -> raw text -> character chunks -> FAISS vector store
-> LangChain "stuff" QA chain backed by a fine-tuned Gemma-2 model.
"""

import io

import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Module-level state, populated once the user confirms a PDF.
# NOTE(review): module globals in Streamlit are shared by every session of
# this server process, not per-user — confirm whether st.session_state was
# intended instead.
knowledge_base = None  # FAISS vector store over the uploaded paper's chunks
qa_chain = None        # load_qa_chain(...) result wrapping the local LLM


def load_pdf(pdf_file):
    """Extract and concatenate the text of every page of *pdf_file*.

    ``extract_text()`` may return ``None`` for image-only pages; those are
    treated as empty strings instead of crashing the join.
    """
    pdf_reader = PdfReader(pdf_file)
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)


def split_text(text):
    """Split *text* into overlapping chunks suitable for embedding."""
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)


def create_knowledge_base(chunks):
    """Embed *chunks* and return a FAISS vector store over them."""
    model_name = "sentence-transformers/all-mpnet-base-v2"  # explicit embedding model
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)


def load_model():
    """Load the fine-tuned Gemma-2 model and return a text-generation pipeline."""
    model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, cache_dir=None, clean_up_tokenization_spaces=False
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=None)
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=150,
        temperature=0.1,
    )


def setup_qa_chain():
    """Initialise the global ``qa_chain``; leaves it as-is if the model fails to load."""
    global qa_chain
    try:
        pipe = load_model()
    except Exception as e:
        # Best-effort: the chat page reports "QA chain is not initialized."
        print(f"Error loading model: {e}")
        return
    llm = HuggingFacePipeline(pipeline=pipe)
    qa_chain = load_qa_chain(llm, chain_type="stuff")


def main_page():
    """Upload page: validate the PDF, preview its text, and build the QA chain."""
    st.title("Welcome to GemmaPaperQA")
    st.subheader("Upload Your Paper")

    paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
    if paper:
        # FIX: the original f-string contained a raw line break, which is a
        # syntax error in a single-quoted literal; use an explicit escape.
        st.write(f"Upload complete!\nFile name: {paper.name}")

        # Reject oversized files before doing any work (10 MB cap).
        file_size = paper.size  # size is available without moving the read pointer
        if file_size > 10 * 1024 * 1024:
            st.error("File is too large! Please upload a file smaller than 10MB.")
            return

        # Intermediate confirmation step — preview the extracted text first.
        with st.spinner("Processing PDF..."):
            try:
                paper.seek(0)  # rewind in case the uploader was read before
                contents = paper.read()
                pdf_file = io.BytesIO(contents)
                text = load_pdf(pdf_file)

                # Scanned/image-only PDFs yield no extractable text.
                if len(text.strip()) == 0:
                    st.error(
                        "The PDF appears to have no extractable text. "
                        "Please check the file and try again."
                    )
                    return

                st.text_area("Preview of extracted text", text[:1000], height=200)
                st.write(f"Total characters extracted: {len(text)}")

                global knowledge_base
                if st.button("Proceed with this file"):
                    chunks = split_text(text)
                    knowledge_base = create_knowledge_base(chunks)
                    if knowledge_base is None:
                        st.error("Failed to create knowledge base.")
                        return
                    setup_qa_chain()
                    # FIX: [:-4] mangled names lacking a ".pdf" suffix;
                    # rsplit drops only a trailing extension if present.
                    st.session_state.paper_name = paper.name.rsplit(".", 1)[0]
                    st.session_state.page = "chat"
                    st.success("PDF successfully processed! You can now ask questions.")
                    # FIX: without an explicit rerun the chat page only
                    # appears on the user's *next* interaction.
                    st.rerun()
            except Exception as e:
                st.error(f"Failed to process the PDF: {str(e)}")


def chat_page():
    """Chat page: render history and answer questions about the loaded paper."""
    st.title(f"Ask anything about {st.session_state.paper_name}")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Chat here!"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        response = get_response_from_model(prompt)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    if st.button("Go back to main page"):
        st.session_state.page = "main"
        # FIX: rerun immediately so the main page renders on this click.
        st.rerun()


def get_response_from_model(prompt):
    """Answer *prompt* via similarity search + the QA chain; returns a string.

    Returns a human-readable error string (never raises) when the knowledge
    base or chain is missing, or when the chain itself fails.
    """
    try:
        global knowledge_base, qa_chain
        if not knowledge_base:
            return "No PDF has been uploaded yet."
        if not qa_chain:
            return "QA chain is not initialized."

        docs = knowledge_base.similarity_search(prompt)
        response = qa_chain.run(input_documents=docs, question=prompt)

        # The stuff-chain prompt ends with "Helpful Answer:"; keep only what
        # the model generated after it.
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:")[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"


# Per-session routing state.
if "page" not in st.session_state:
    st.session_state.page = "main"
if "paper_name" not in st.session_state:
    st.session_state.paper_name = ""

# Render the page selected by the router.
if st.session_state.page == "main":
    main_page()
elif st.session_state.page == "chat":
    chat_page()