halyn committed
Commit 6692e0b • 1 Parent(s): e2ce39d

code update

Files changed (2)
  1. app.py +51 -36
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import io
-import requests
 import streamlit as st
 from PyPDF2 import PdfReader
 from langchain.text_splitter import CharacterTextSplitter
@@ -14,74 +13,90 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 knowledge_base = None
 qa_chain = None

+# Load a PDF file and extract its text
 def load_pdf(pdf_file):
-    """
-    Load and extract text from a PDF.
-    """
     pdf_reader = PdfReader(pdf_file)
     text = "".join(page.extract_text() for page in pdf_reader.pages)
     return text

+# Split the extracted text into chunks
 def split_text(text):
-    """
-    Split the extracted text into chunks.
-    """
     text_splitter = CharacterTextSplitter(
         separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
     )
     return text_splitter.split_text(text)

+# Create the FAISS vector store
 def create_knowledge_base(chunks):
-    """
-    Create a FAISS knowledge base from text chunks.
-    """
     embeddings = HuggingFaceEmbeddings()
     return FAISS.from_texts(chunks, embeddings)

-def load_model(model_path):
-    """
-    Load the HuggingFace model and tokenizer, and create a text-generation pipeline.
-    """
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForCausalLM.from_pretrained(model_path)
+# Load the Hugging Face model
+def load_model():
+    model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
     return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)

+# Set up the QA chain
 def setup_qa_chain():
-    """
-    Set up the question-answering chain.
-    """
     global qa_chain
-    pipe = load_model(MODEL_PATH)
+    pipe = load_model()
     llm = HuggingFacePipeline(pipeline=pipe)
     qa_chain = load_qa_chain(llm, chain_type="stuff")

-# Streamlit UI
+
+# Main page UI
 def main_page():
     st.title("Welcome to GemmaPaperQA")
     st.subheader("Upload Your Paper")

     paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
     if paper:
-        st.write(f"Upload complete! File name is {paper.name}")
-        st.write("Please click the button below.")
-
-        if st.button("Click Here :)"):
+        st.write(f"Upload complete! File name: {paper.name}")
+        # Check the file size
+        file_size = paper.size  # read the size without moving the file pointer
+        if file_size > 10 * 1024 * 1024:  # 10 MB limit
+            st.error("File is too large! Please upload a file smaller than 10MB.")
+            return
+
+        # Intermediate check: preview the PDF contents
+        with st.spinner('Processing PDF...'):
             try:
-                # Process the PDF file
+                paper.seek(0)  # rewind the read pointer to the start
                 contents = paper.read()
                 pdf_file = io.BytesIO(contents)
                 text = load_pdf(pdf_file)
-                chunks = split_text(text)
-                global knowledge_base
-                knowledge_base = create_knowledge_base(chunks)
-
-                st.success("PDF successfully processed! You can now ask questions.")
-                st.session_state.paper_name = paper.name[:-4]
-                st.session_state.page = "chat"
-                setup_qa_chain()
+
+                # Error out if no text could be extracted
+                if len(text.strip()) == 0:
+                    st.error("The PDF appears to have no extractable text. Please check the file and try again.")
+                    return
+
+                st.text_area("Preview of extracted text", text[:1000], height=200)
+                st.write(f"Total characters extracted: {len(text)}")
+
+                if st.button("Proceed with this file"):
+                    chunks = split_text(text)
+                    global knowledge_base
+                    knowledge_base = create_knowledge_base(chunks)
+
+                    if knowledge_base is None:
+                        st.error("Failed to create knowledge base.")
+                        return
+
+                    st.session_state.paper_name = paper.name[:-4]
+                    st.session_state.page = "chat"
+                    setup_qa_chain()
+                    st.success("PDF successfully processed! You can now ask questions.")
+
             except Exception as e:
                 st.error(f"Failed to process the PDF: {str(e)}")

+
+# Chat page UI
 def chat_page():
     st.title(f"Ask anything about {st.session_state.paper_name}")

@@ -108,6 +123,7 @@ def chat_page():
     if st.button("Go back to main page"):
         st.session_state.page = "main"

+# Handle model responses
 def get_response_from_model(prompt):
     try:
         global knowledge_base, qa_chain

@@ -126,11 +142,10 @@ def get_response_from_model(prompt):
     except Exception as e:
         return f"Error: {str(e)}"

-# Streamlit - initial page setup
+# Page setup
 if "page" not in st.session_state:
     st.session_state.page = "main"

-# Initialize paper_name
 if "paper_name" not in st.session_state:
     st.session_state.paper_name = ""
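Note: the hunks above elide the body of get_response_from_model. A minimal sketch of how knowledge_base and qa_chain would typically be wired together with a LangChain "stuff" chain is given below; the similarity_search call, the k=4 value, and the guard message are illustrative assumptions, not code from this commit.

# Hypothetical sketch; the commit's actual implementation is elided in the diff above.
def get_response_from_model(prompt):
    try:
        global knowledge_base, qa_chain
        if knowledge_base is None or qa_chain is None:  # assumed guard
            return "Please upload and process a PDF first."

        # Retrieve the chunks most similar to the question from the FAISS store
        docs = knowledge_base.similarity_search(prompt, k=4)

        # The "stuff" chain packs the retrieved chunks into a single prompt
        return qa_chain.run(input_documents=docs, question=prompt)
    except Exception as e:
        return f"Error: {str(e)}"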
requirements.txt CHANGED
@@ -5,4 +5,4 @@ transformers==4.31.0
 torch==2.0.1
 faiss-cpu==1.7.4
 requests==2.31.0
-huggingface-hub==0.16.4
+huggingface-hub==0.16.4