halyn committed
Commit 3e1aa0b
1 Parent(s): a80fb91

code update

Files changed (1)
  1. app.py +7 -9
app.py CHANGED
@@ -2,12 +2,13 @@ import io
 import streamlit as st
 from PyPDF2 import PdfReader
 from langchain.text_splitter import CharacterTextSplitter
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
 from langchain.chains.question_answering import load_qa_chain
-from langchain.llms import HuggingFacePipeline
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_community.llms import HuggingFacePipeline
+
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from peft import PeftModel, PeftConfig
+
 
 # Global variables
 knowledge_base = None
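
This hunk moves the HuggingFaceEmbeddings, FAISS, and HuggingFacePipeline imports from their deprecated langchain.* paths to the langchain-community package, and drops the now-unused peft imports. For context, a minimal sketch of how the migrated embedding and vector-store imports are typically combined; the body of create_knowledge_base is not shown in this diff, and the embedding model name below is an assumed placeholder:

# Hypothetical sketch only: the diff names create_knowledge_base(chunks)
# but does not show its body. Assumes langchain-community and faiss-cpu
# are installed; the embedding model is an illustrative default.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

def create_knowledge_base(chunks):
    # Embed each text chunk and index the vectors in an in-memory FAISS store.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_texts(chunks, embeddings)
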
@@ -38,11 +39,9 @@ def create_knowledge_base(chunks):
 def load_model():
     model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    config = PeftConfig.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
-    model = PeftModel.from_pretrained(model, model_name)
-
+    model = AutoModelForCausalLM.from_pretrained(model_name)
     return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
+
 # QA chain setup
 def setup_qa_chain():
     global qa_chain
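
This hunk simplifies load_model(): rather than loading the base checkpoint and layering a PEFT adapter on top, it loads halyn/gemma2-2b-it-finetuned-paperqa directly with AutoModelForCausalLM, which works when the adapter weights have been merged into the published checkpoint. The returned pipeline is presumably wrapped for LangChain via the newly imported HuggingFacePipeline; a hedged sketch of setup_qa_chain, whose body this diff does not show (the chain_type is an assumed default):

# Hypothetical sketch: wires load_model()'s pipeline into a LangChain QA chain.
from langchain_community.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain

def setup_qa_chain():
    global qa_chain
    llm = HuggingFacePipeline(pipeline=load_model())  # wrap the HF pipeline as a LangChain LLM
    qa_chain = load_qa_chain(llm, chain_type="stuff")  # "stuff" chain type is an assumption
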
@@ -99,7 +98,6 @@ def main_page():
         st.error(f"Failed to process the PDF: {str(e)}")
 
 
-
 # Chat page UI
 def chat_page():
     st.title(f"Ask anything about {st.session_state.paper_name}")
 
2
  import streamlit as st
3
  from PyPDF2 import PdfReader
4
  from langchain.text_splitter import CharacterTextSplitter
 
 
5
  from langchain.chains.question_answering import load_qa_chain
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_community.llms import HuggingFacePipeline
9
+
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+
12
 
13
  # Global variables
14
  knowledge_base = None
 
39
  def load_model():
40
  model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
43
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
44
+
45
  # QA 체인 설정
46
  def setup_qa_chain():
47
  global qa_chain
 
98
  st.error(f"Failed to process the PDF: {str(e)}")
99
 
100
 
 
101
  # 채팅 페이지 UI
102
  def chat_page():
103
  st.title(f"Ask anything about {st.session_state.paper_name}")