move settings on the sidebar, allow env variables
streamlit_app.py  (+78 -57)  CHANGED
@@ -18,11 +18,14 @@ from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 from grobid_client_generic import GrobidClientGeneric
 
 if 'rqa' not in st.session_state:
-    st.session_state['rqa'] =
+    st.session_state['rqa'] = {}
 
 if 'api_key' not in st.session_state:
     st.session_state['api_key'] = False
 
+if 'api_keys' not in st.session_state:
+    st.session_state['api_keys'] = {}
+
 if 'doc_id' not in st.session_state:
     st.session_state['doc_id'] = None
 
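Note: st.session_state['rqa'] turns from a single engine into a dict keyed by model name, with a parallel api_keys dict. Streamlit re-executes the whole script on every widget interaction, so mutable state must be guarded like this or it would reset on each rerun. A minimal sketch of the pattern (get_engine is a hypothetical helper, not in the app):

import streamlit as st

# Guarded init: runs once per browser session, survives script reruns.
if 'rqa' not in st.session_state:
    st.session_state['rqa'] = {}        # one QA engine per model name
if 'api_keys' not in st.session_state:
    st.session_state['api_keys'] = {}   # one API key per model name

def get_engine(model_name, factory):
    # Build the engine once, then reuse it on later reruns.
    if model_name not in st.session_state['rqa']:
        st.session_state['rqa'][model_name] = factory(model_name)
    return st.session_state['rqa'][model_name]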
@@ -42,13 +45,16 @@ if 'git_rev' not in st.session_state:
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if 'ner_processing' not in st.session_state:
+    st.session_state['ner_processing'] = False
+
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
 
 
-@st.cache_resource
+# @st.cache_resource
 def init_qa(model):
     if model == 'chatgpt-3.5-turbo':
         chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
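The @st.cache_resource decorator is commented out rather than deleted. A plausible reason, given the env-variable handling added later in this commit: cache_resource keys only on the function arguments, so an engine built before the user supplied an API key would keep being served. Sketch under that assumption (make_engine is a hypothetical stand-in):

import os
import streamlit as st

def make_engine(model, api_key):
    # Stand-in for the real engine constructor, for illustration only.
    return {"model": model, "api_key": api_key}

@st.cache_resource          # the cache key is (model,) and nothing else
def init_qa_cached(model: str):
    # os.environ is read inside the body, so it is invisible to the cache
    # key: an engine created before the user set a key would be reused.
    return make_engine(model, os.environ.get('OPENAI_API_KEY'))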
@@ -67,6 +73,7 @@ def init_qa(model):
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
     else:
         st.error("The model was not loaded properly. Try reloading. ")
+        st.stop()
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
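Adding st.stop() after st.error closes a real gap: without it, execution would fall through to the return statement with chat and embeddings unbound and raise a NameError. The same report-and-halt pattern in isolation (load_or_stop is a hypothetical name):

import streamlit as st

def load_or_stop(name: str, registry: dict):
    if name not in registry:
        st.error(f"Unknown model: {name}")
        st.stop()   # ends this script run here; no code below executes
    return registry[name]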
@@ -94,7 +101,6 @@ def init_ner():
         grobid_quantities_client=quantities_client,
         grobid_superconductors_client=materials_client
     )
-
     return gqa
 
 
@@ -125,51 +131,52 @@ def play_old_messages():
 
 is_api_key_provided = st.session_state['api_key']
 
-
-
-
-
-
-
-
-
-
-
-
-
+with st.sidebar:
+    model = st.radio(
+        "Model (cannot be changed after selection or upload)",
+        ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
+        index=1,
+        captions=[
+            "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
+            "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
+            # "LLama2-70B-Chat + Sentence BERT (embeddings)",
+        ],
+        help="Select the model you want to use.")
+
     if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-        api_key = st.
-
+        api_key = st.text_input('Huggingface API Key',
+                                type="password") if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ[
+            'HUGGINGFACEHUB_API_TOKEN']
+        st.markdown(
+            "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
+
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
-
-
+            st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
+            if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
+                os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
+            st.session_state['rqa'][model] = init_qa(model)
+
     elif model == 'chatgpt-3.5-turbo':
-        api_key = st.
-
+        api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
+            os.environ['OPENAI_API_KEY']
+        st.markdown(
+            "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
-
-
-
-
+            st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
+            if 'OPENAI_API_KEY' not in os.environ:
+                os.environ['OPENAI_API_KEY'] = api_key
+            st.session_state['rqa'][model] = init_qa(model)
+    # else:
+    #     is_api_key_provided = st.session_state['api_key']
 
 st.title("📝 Scientific Document Insight Q&A")
 st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
 
-
-
-
-                                     disabled=not is_api_key_provided,
-                                     help="The full-text is extracted using Grobid. ")
-with radio_col:
-    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0,
-                    help="LLM will respond the question, Embedding will show the "
-                         "paragraphs relevant to the question in the paper.")
-with context_col:
-    context_size = st.slider("Context size", 3, 10, value=4,
-                             help="Number of paragraphs to consider when answering a question",
-                             disabled=not uploaded_file)
+uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
+                                 disabled=not is_api_key_provided,
+                                 help="The full-text is extracted using Grobid. ")
 
 question = st.chat_input(
     "Ask something about the article",
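The key handling above is the "allow env variables" half of the commit: the text input is only shown when the corresponding variable is absent from the environment, so a deployed Space can ship with preconfigured secrets while local users type a key. The same logic as a generic sketch (resolve_api_key is hypothetical; the app inlines it per provider):

import os
import streamlit as st

def resolve_api_key(env_var: str, label: str) -> str:
    # Prefer a key injected via the environment (e.g. Space secrets);
    # fall back to asking the user, masking what they type.
    if env_var in os.environ:
        return os.environ[env_var]
    return st.text_input(label, type="password")

api_key = resolve_api_key('OPENAI_API_KEY', 'OpenAI API Key')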
@@ -178,14 +185,29 @@ question = st.chat_input(
 )
 
 with st.sidebar:
+    st.header("Settings")
+    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
+                    help="LLM will respond the question, Embedding will show the "
+                         "paragraphs relevant to the question in the paper.")
+    chunk_size = st.slider("Chunks size", 100, 2000, value=250,
+                           help="Size of chunks in which the document is partitioned",
+                           disabled=not uploaded_file)
+    context_size = st.slider("Context size", 3, 10, value=4,
+                             help="Number of chunks to consider when answering a question",
+                             disabled=not uploaded_file)
+
+    st.session_state['ner_processing'] = st.checkbox("NER processing on LLM response")
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
+        unsafe_allow_html=True)
+
+    st.divider()
+
     st.header("Documentation")
     st.markdown("https://github.com/lfoppiano/document-qa")
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
 
-    st.markdown(
-        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
-        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
             'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
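The two sliders moved into the sidebar jointly bound how much of the paper reaches the model: roughly chunk_size units per chunk times context_size retrieved chunks. Back-of-envelope with the defaults (whether a unit is a character or a token depends on how DocumentQAEngine measures chunks, which this diff does not show):

chunk_size = 250     # "Chunks size" slider default
context_size = 4     # "Context size" slider default
print(chunk_size * context_size)  # -> 1000 units of retrieved context per question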
@@ -203,9 +225,9 @@ if uploaded_file and not st.session_state.loaded_embeddings:
     tmp_file = NamedTemporaryFile()
     tmp_file.write(bytearray(binary))
     # hash = get_file_hash(tmp_file.name)[:10]
-    st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name,
-
-
+    st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
+                                                                                                chunk_size=chunk_size,
+                                                                                                perc_overlap=0.1)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
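create_memory_embeddings now takes the user-chosen chunk_size plus a fixed perc_overlap=0.1. A rough sketch of what fractional overlap means for the splitting step (illustration only; the real partitioning happens inside DocumentQAEngine):

def split_with_overlap(text: str, chunk_size: int = 250, perc_overlap: float = 0.1):
    # Consecutive chunks share perc_overlap * chunk_size units, so a sentence
    # cut at a boundary still appears intact in the neighbouring chunk.
    step = max(1, int(chunk_size * (1 - perc_overlap)))
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

chunks = split_with_overlap("a long full-text extracted by Grobid " * 50)
print(len(chunks), chunks[0][-25:], chunks[1][:25])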
@@ -226,27 +248,26 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
     text_response = None
     if mode == "Embeddings":
         with st.spinner("Generating LLM response..."):
-            text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-
+            text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
+                                                                         context_size=context_size)
     elif mode == "LLM":
         with st.spinner("Generating response..."):
-            _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-
+            _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
+                                                                             context_size=context_size)
 
     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
     with st.chat_message("assistant"):
         if mode == "LLM":
-
-
-
-
-
-
-
-
-            text_response = decorated_text
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            st.markdown(text_response, unsafe_allow_html=True)
         else:
             st.write(text_response)
         st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
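The new NER branch decorates the LLM answer: entity spans tagged with class="label ..." are rewritten to inline colors that st.markdown can render with unsafe_allow_html. Order matters: materials are turned green first, then the catch-all regex paints every remaining label orange. Standalone sketch (the sample span markup is an assumption about what decorate_text_with_annotations emits):

import re

html = ('The sample shows <span class="label material">MgB2</span> '
        'below <span class="label temperature">39 K</span>.')

html = html.replace('class="label material"', 'style="color:green"')   # materials -> green
html = re.sub(r'class="label[^"]+"', 'style="color:orange"', html)     # everything else -> orange
print(html)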