Add "Search Only" to OpenAI model options and make OpenAI API key input optional

#1
by shinichi-a - opened
Files changed (1)
  1. app.py +22 -36
app.py CHANGED
@@ -1,23 +1,15 @@
-"""
-streamlit run app.py --server.address 0.0.0.0
-"""
-
 from __future__ import annotations
 
-import streamlit as st
 import os
-
-import faiss
-from sentence_transformers import SentenceTransformer
+import pandas as pd
 import torch
-from openai import OpenAI
+import faiss
 import streamlit as st
-import pandas as pd
-import os
 from time import time
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
 from datasets.download import DownloadManager
-from datasets import load_dataset  # type: ignore
-
 
 WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
 WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
@@ -36,6 +28,7 @@ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
 OPENAI_MODEL_NAMES = [
     "gpt-3.5-turbo-1106",
     "gpt-4-1106-preview",
+    "Search Only",
 ]
 
 E5_QUERY_TYPES = [
@@ -60,7 +53,6 @@ Responses must be given in Japanese.
 {question}
 """.strip()
 
-
 if os.getenv("SPACE_ID"):
     USE_HF_SPACE = True
     os.environ["HF_HOME"] = "/data/.huggingface"
@@ -68,9 +60,7 @@ if os.getenv("SPACE_ID"):
 else:
     USE_HF_SPACE = False
 
-# for tokenizer
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
 
@@ -81,6 +71,7 @@ def get_model(name: str, max_seq_length=512):
         device = "cuda"
     elif torch.backends.mps.is_available():
         device = "mps"
+
     model = SentenceTransformer(name, device=device)
     model.max_seq_length = max_seq_length
     return model
@@ -93,9 +84,7 @@ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
 
 
 @st.cache_resource
-def get_faiss_index(
-    index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
-):
+def get_faiss_index(index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME):
     target_path = f"faiss_indexes/{name}/{index_name}"
     dm = DownloadManager()
     index_local_path = dm.download(
@@ -110,9 +99,7 @@ def text_to_emb(model, text: str, prefix: str):
     return model.encode([prefix + text], normalize_embeddings=True)
 
 
-def search(
-    faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
-):
+def search(faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int):
     start_time = time()
     emb = text_to_emb(emb_model, question, search_text_prefix)
     emb_exec_time = time() - start_time
@@ -121,7 +108,7 @@ def search(
     scores = scores[0]
     indexes = indexes[0]
     results = []
-    for idx, score in zip(indexes, scores):  # type: ignore
+    for idx, score in zip(indexes, scores):
         idx = int(idx)
         passage = ds[idx]
         results.append((score, passage))
@@ -133,7 +120,6 @@ def to_contexts(passages):
     for passage in passages:
         title = passage["title"]
         text = passage["text"]
-        # section = passage["section"]
         contexts += f"- {title}: {text}\n"
     return contexts
 
@@ -211,15 +197,13 @@ def app():
         key="question",
         value="楽曲『約束はいらない』でデビューした、声優は誰?",
     )
-    if not OPENAI_API_KEY:
-        st.text_input(
-            "OpenAI API Key",
-            key="openai_api_key",
-            type="password",
-            placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
-        )
-    else:
-        st.session_state.openai_api_key = OPENAI_API_KEY
+    st.text_input(
+        "OpenAI API Key",
+        key="openai_api_key",
+        type="password",
+        value=OPENAI_API_KEY if OPENAI_API_KEY else "",
+        placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
+    )
 
     with st.expander("オプション"):
         option_cols_main = st.columns(2)
@@ -229,6 +213,8 @@ def app():
             st.selectbox(
                 "OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
            )
+        if "emb_model_name" not in st.session_state:
+            st.session_state.emb_model_name = EMB_MODEL_NAMES[0]  # replace with the actual default value you want to use
         emb_model_name = st.session_state.emb_model_name
         option_cols_sub = st.columns(2)
         with option_cols_sub[0]:
@@ -300,10 +286,10 @@ def app():
     st.dataframe(df, hide_index=True)
 
     openai_api_key = st.session_state.openai_api_key
-    if openai_api_key:
+    openai_model_name = st.session_state.openai_model_name
+    if openai_api_key and openai_model_name != "Search Only":
        openai_api_key = openai_api_key.strip()
        answer_header.subheader("Answer: ")
-        openai_model_name = st.session_state.openai_model_name
        temperature = st.session_state.temperature
        qa_prompt = st.session_state.qa_prompt
        max_tokens = st.session_state.max_tokens
@@ -320,4 +306,4 @@ def app():
 
 
 if __name__ == "__main__":
-    app()
+    app()
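
For reference, a minimal sketch (not part of the patch) of the answer-generation gating this change introduces, built from the session-state keys used in the diff (openai_api_key, openai_model_name). The helper name maybe_generate_answer and the chat-completion call shown here are illustrative only; in app.py the gating wraps the existing OpenAI call.

    from __future__ import annotations

    from openai import OpenAI

    def maybe_generate_answer(openai_api_key: str, openai_model_name: str, prompt: str) -> str | None:
        # Search-only mode: skip generation when no API key was entered
        # or when the "Search Only" pseudo-model is selected.
        if not openai_api_key or openai_model_name == "Search Only":
            return None
        client = OpenAI(api_key=openai_api_key.strip())
        response = client.chat.completions.create(
            model=openai_model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content

In the skipped case the app still runs the FAISS search and shows the results table; only the "Answer:" section is omitted, matching the placeholder text of the new API key input.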