halyn committed on
Commit f085c10
1 Parent(s): 80a4c83

update code

Files changed (1)
  1. app.py +217 -88
app.py CHANGED
@@ -1,93 +1,222 @@
-import streamlit as st
 import requests
 from PyPDF2 import PdfReader
 
-
-st.title("Welcome to GemmaPaperQA")
-st.subheader("Upload Your Paper")
-
-# def main_page():
-
-#     paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
-#     if paper:
-#         st.write(f"Upload complete! File name is {paper.name}")
-#         st.write("Please click the button below.")
-#         # pdf_reader = PdfReader(paper)
-#         # for page in pdf_reader.pages:
-#         #     paper_title.append(page.extract_text())
-#         #     break
-#         # paper_name = paper_title[0].split("\n")[0]
-
-#         # st.subheader(f"You upload the <{paper_name}> paper")
-
-#         if st.button("Click Here :)"):
-#             # Send the PDF file to the FastAPI server
-#             try:
-#                 files = {"file": (paper.name, paper, "application/pdf")}
-#                 response = requests.post(f"{FASTAPI_URL}/upload_pdf", files=files)
-#                 if response.status_code == 200:
-#                     st.success("PDF successfully uploaded to the model! Please click the button again")
-#                     st.session_state.messages = []
-#                     st.session_state.paper_name = paper.name[:-4]
-#                     st.session_state.page = "chat"
-#                 else:
-#                     st.error(f"Failed to upload PDF to the model. Error: {response.text}")
-#             except requests.RequestException as e:
-#                 st.error(f"Error connecting to the server: {str(e)}")
-
-# def chat_page():
-#     st.title(f"Welcome to GemmaPaperQA")
-#     st.subheader(f"Ask anything about {st.session_state.paper_name}")
-
-#     if "messages" not in st.session_state:
-#         st.session_state.messages = []
-
-#     for message in st.session_state.messages:
-#         with st.chat_message(message["role"]):
-#             st.markdown(message["content"])
 
-#     if prompt := st.chat_input("Chat here !"):
-#         # Add user message to chat history
-#         st.session_state.messages.append({"role": "user", "content": prompt})
-
-#         # Display user message in chat message container
-#         with st.chat_message("user"):
-#             st.markdown(prompt)
-
-#         # Get response from FastAPI server
-#         response = get_response_from_fastapi(prompt)
-
-#         # Display assistant response in chat message container
-#         with st.chat_message("assistant"):
-#             st.markdown(response)
-
-#         # Add assistant response to chat history
-#         st.session_state.messages.append({"role": "assistant", "content": response})
-
-#     if st.button("Go back to main page"):
-#         st.session_state.page = "main"
-
-# def get_response_from_fastapi(prompt):
-#     try:
-#         response = requests.post(f"{FASTAPI_URL}/ask", json={"text": prompt})
-#         if response.status_code == 200:
-#             return response.json()["response"]
-#         else:
-#             return f"Sorry, I couldn't generate a response. Error: {response.text}"
-#     except requests.RequestException as e:
-#         return f"Sorry, there was an error connecting to the server: {str(e)}"
-
-# # Initial page setup
-# if "page" not in st.session_state:
-#     st.session_state.page = "main"
-
-# # Initialize paper_name
-# if "paper_name" not in st.session_state:
-#     st.session_state.paper_name = ""
-
-# # Page rendering
-# if st.session_state.page == "main":
-#     main_page()
-# elif st.session_state.page == "chat":
-#     chat_page()
+import os
+import io
 import requests
+from dotenv import load_dotenv
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 from PyPDF2 import PdfReader
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+from langchain.chains.question_answering import load_qa_chain
+from langchain.llms import HuggingFacePipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import streamlit as st
 
+# Disable WANDB
+os.environ['WANDB_DISABLED'] = "true"
+
+# Constants
+MODEL_PATH = "/home/lab/halyn/gemma/halyn/paper/models/gemma-2-9b-it"
+FASTAPI_URL = "http://203.249.64.50:8080"  # Server address
+
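+# NOTE: FASTAPI_URL targets port 8080 while the __main__ block at the bottom
+# binds uvicorn to 8050; this assumes something external (e.g. a proxy) maps
+# between the two, so align the ports if everything runs on one machine.
+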
+app = FastAPI()
+
+# CORS configuration
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allow all origins
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global variables to store the knowledge base and QA chain
+knowledge_base = None
+qa_chain = None
+
+def load_pdf(pdf_file):
+    """
+    Load and extract text from a PDF.
+    Args:
+        pdf_file: The PDF file (a path or a file-like object).
+    Returns:
+        str: Extracted text from the PDF.
+    """
+    pdf_reader = PdfReader(pdf_file)
+    # extract_text() can return None for image-only pages, so fall back to ""
+    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
+    return text
+
+def split_text(text):
+    """
+    Split the extracted text into chunks.
+    Args:
+        text (str): The full text extracted from the PDF.
+    Returns:
+        list: A list of text chunks.
+    """
+    text_splitter = CharacterTextSplitter(
+        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
+    )
+    return text_splitter.split_text(text)
+
+def create_knowledge_base(chunks):
+    """
+    Create a FAISS knowledge base from text chunks.
+    Args:
+        chunks (list): A list of text chunks.
+    Returns:
+        FAISS: A FAISS knowledge base object.
+    """
+    embeddings = HuggingFaceEmbeddings()
+    return FAISS.from_texts(chunks, embeddings)
+
+def load_model(model_path):
+    """
+    Load the HuggingFace model and tokenizer, and create a text-generation pipeline.
+    Args:
+        model_path (str): The path to the pre-trained model.
+    Returns:
+        pipeline: A HuggingFace pipeline for text generation.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
+
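+# Ingestion flow: load_pdf -> split_text -> create_knowledge_base. Each upload
+# re-embeds the whole paper, and FAISS then serves similarity_search() queries
+# over those chunks at question time.
+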
+@app.on_event("startup")
+async def startup_event():
+    """Load the language model and build the QA chain when the server starts."""
+    global qa_chain
+    load_dotenv()
 
+    # Load the language model
+    try:
+        pipe = load_model(MODEL_PATH)
+        llm = HuggingFacePipeline(pipeline=pipe)
+        qa_chain = load_qa_chain(llm, chain_type="stuff")
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        raise HTTPException(status_code=500, detail="Failed to load the language model")
+
+@app.post("/upload_pdf")
+async def upload_pdf(file: UploadFile = File(...)):
+    global knowledge_base
+    try:
+        contents = await file.read()
+        pdf_file = io.BytesIO(contents)
+        text = load_pdf(pdf_file)
+        chunks = split_text(text)
+        knowledge_base = create_knowledge_base(chunks)
+        return {"message": "PDF uploaded and processed successfully"}
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Failed to process PDF: {str(e)}")
+
+class Question(BaseModel):
+    text: str
+
+@app.post("/ask")
+async def ask_question(question: Question):
+    global knowledge_base, qa_chain
+    if not knowledge_base:
+        raise HTTPException(status_code=400, detail="No PDF has been uploaded yet")
+    if not qa_chain:
+        raise HTTPException(status_code=500, detail="QA chain is not initialized")
+
+    try:
+        docs = knowledge_base.similarity_search(question.text)
+        response = qa_chain.run(input_documents=docs, question=question.text)
+
+        # Keep only the text after the chain's "Helpful Answer:" marker, if present
+        if "Helpful Answer:" in response:
+            response = response.split("Helpful Answer:")[1].strip()
+
+        return {"response": response}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
+
+
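+# Note: knowledge_base is a module-level global, so a single uploaded paper is
+# shared by every client until the next upload or a restart; the "stuff" chain
+# concatenates all retrieved chunks into one prompt for the model.
+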
+# Streamlit UI
+def main_page():
+    st.title("Welcome to GemmaPaperQA")
+    st.subheader("Upload Your Paper")
+
+    paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
+    if paper:
+        st.write(f"Upload complete! File name is {paper.name}")
+        st.write("Please click the button below.")
+
+        if st.button("Click Here :)"):
+            # Send the PDF file to the FastAPI server
+            try:
+                files = {"file": (paper.name, paper, "application/pdf")}
+                response = requests.post(f"{FASTAPI_URL}/upload_pdf", files=files)
+                if response.status_code == 200:
+                    st.success("PDF successfully uploaded to the model! Please click the button again")
+                    st.session_state.messages = []
+                    st.session_state.paper_name = paper.name[:-4]
+                    st.session_state.page = "chat"
+                else:
+                    st.error(f"Failed to upload PDF to the model. Error: {response.text}")
+            except requests.RequestException as e:
+                st.error(f"Error connecting to the server: {str(e)}")
+
+def chat_page():
+    st.title("Welcome to GemmaPaperQA")
+    st.subheader(f"Ask anything about {st.session_state.paper_name}")
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
 
+    if prompt := st.chat_input("Chat here !"):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        # Display user message in chat message container
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # Get response from FastAPI server
+        response = get_response_from_fastapi(prompt)
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant"):
+            st.markdown(response)
+
+        # Add assistant response to chat history
+        st.session_state.messages.append({"role": "assistant", "content": response})
+
+    if st.button("Go back to main page"):
+        st.session_state.page = "main"
+
+def get_response_from_fastapi(prompt):
+    try:
+        response = requests.post(f"{FASTAPI_URL}/ask", json={"text": prompt})
+        if response.status_code == 200:
+            return response.json()["response"]
+        else:
+            return f"Sorry, I couldn't generate a response. Error: {response.text}"
+    except requests.RequestException as e:
+        return f"Sorry, there was an error connecting to the server: {str(e)}"
+
+# Streamlit - initial page setup
+if "page" not in st.session_state:
+    st.session_state.page = "main"
+
+# Initialize paper_name
+if "paper_name" not in st.session_state:
+    st.session_state.paper_name = ""
+
+# Page rendering
+if st.session_state.page == "main":
+    main_page()
+elif st.session_state.page == "chat":
+    chat_page()
+
+# Code to run the FastAPI app
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8050)
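
A quick way to exercise the new endpoints from outside the Streamlit UI (a minimal sketch, not part of the commit: "paper.pdf" and the question string are placeholders, and the base URL assumes the FASTAPI_URL above is reachable):

    import requests

    BASE_URL = "http://203.249.64.50:8080"  # FASTAPI_URL from app.py

    # Upload a paper; the server extracts, chunks, and indexes its text
    with open("paper.pdf", "rb") as f:
        r = requests.post(f"{BASE_URL}/upload_pdf",
                          files={"file": ("paper.pdf", f, "application/pdf")})
    r.raise_for_status()
    print(r.json()["message"])

    # Ask a question against the indexed paper
    r = requests.post(f"{BASE_URL}/ask",
                      json={"text": "What problem does this paper address?"})
    r.raise_for_status()
    print(r.json()["response"])

Since app.py now mixes the FastAPI backend and the Streamlit front end in one file, the two are presumably launched as separate processes (uvicorn for the API, streamlit run for the UI); the __main__ block only covers the uvicorn side.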