lfoppiano committed on
Commit
20a04c7
2 Parent(s): 17c679c 9e0de2a

Merge branch 'main' into add-pdf-viewer

Files changed (1)
  1. streamlit_app.py +16 -19
streamlit_app.py CHANGED
@@ -23,6 +23,13 @@ OPENAI_MODELS = ['chatgpt-3.5-turbo',
                  "gpt-4",
                  "gpt-4-1106-preview"]
 
+OPEN_MODELS = {
+    'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
+    "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta'
+}
+
+DISABLE_MEMORY = ['zephyr-7b-beta']
+
 if 'rqa' not in st.session_state:
     st.session_state['rqa'] = {}
 
@@ -142,18 +149,14 @@ def init_qa(model, api_key=None):
                           frequency_penalty=0.1)
         embeddings = OpenAIEmbeddings()
 
-    elif model == 'mistral-7b-instruct-v0.1':
-        chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
-                              model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
+    elif model in OPEN_MODELS:
+        chat = HuggingFaceHub(
+            repo_id=OPEN_MODELS[model],
+            model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048}
+        )
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
-        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
-
-    elif model == 'zephyr-7b-beta':
-        chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
-                              model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
-        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
-        st.session_state['memory'] = None
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
@@ -218,14 +221,8 @@ def play_old_messages():
 with st.sidebar:
     st.session_state['model'] = model = st.selectbox(
         "Model:",
-        options=[
-            "chatgpt-3.5-turbo",
-            "mistral-7b-instruct-v0.1",
-            "zephyr-7b-beta",
-            "gpt-4",
-            "gpt-4-1106-preview"
-        ],
-        index=2,
+        options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
+        index=4,
         placeholder="Select model",
         help="Select the LLM model:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
@@ -234,7 +231,7 @@ with st.sidebar:
     st.markdown(
         ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
 
-    if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
+    if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
         if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
             api_key = st.text_input('Huggingface API Key', type="password")
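
For readers skimming the diff, here is a minimal, self-contained sketch of the dispatch pattern this commit introduces: the per-model elif branches are replaced by lookups into the new OPEN_MODELS mapping and DISABLE_MEMORY list. The helper build_open_model_qa below is hypothetical (it is not part of streamlit_app.py), and the langchain import paths are an assumption based on the pre-1.0 package layout the app appears to use.

# Illustrative sketch, not part of the commit: the mapping-driven branch
# from init_qa(), lifted into a standalone helper for clarity.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferWindowMemory

# Constants added by this commit (copied from the diff above).
OPEN_MODELS = {
    'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
    'zephyr-7b-beta': 'HuggingFaceH4/zephyr-7b-beta',
}
DISABLE_MEMORY = ['zephyr-7b-beta']

def build_open_model_qa(model):  # hypothetical helper, for illustration only
    """Return (chat, embeddings, memory) for a model listed in OPEN_MODELS."""
    chat = HuggingFaceHub(
        repo_id=OPEN_MODELS[model],  # raises KeyError if model is not an open model
        model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048},
    )
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # zephyr-7b-beta is listed in DISABLE_MEMORY, so it runs without chat history.
    memory = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
    return chat, embeddings, memory

With this layout the sidebar options become OPENAI_MODELS + list(OPEN_MODELS.keys()); assuming OPENAI_MODELS keeps its three OpenAI entries, index=4 still points at zephyr-7b-beta, so the default selection matches the previous index=2 on the hard-coded five-item list.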