Add "Search Only" to OpenAI model options and make OpenAI API key input optional
#1 by shinichi-a · opened

app.py CHANGED
@@ -1,23 +1,15 @@
-"""
-streamlit run app.py --server.address 0.0.0.0
-"""
-
 from __future__ import annotations
 
-import streamlit as st
 import os
-
-import faiss
-from sentence_transformers import SentenceTransformer
+import pandas as pd
 import torch
-
+import faiss
 import streamlit as st
-import pandas as pd
-import os
 from time import time
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
 from datasets.download import DownloadManager
-from datasets import load_dataset  # type: ignore
-
 
 WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
 WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
@@ -36,6 +28,7 @@ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
 OPENAI_MODEL_NAMES = [
     "gpt-3.5-turbo-1106",
     "gpt-4-1106-preview",
+    "Search Only",
 ]
 
 E5_QUERY_TYPES = [
@@ -60,7 +53,6 @@ Responses must be given in Japanese.
 {question}
 """.strip()
 
-
 if os.getenv("SPACE_ID"):
     USE_HF_SPACE = True
     os.environ["HF_HOME"] = "/data/.huggingface"
@@ -68,9 +60,7 @@ if os.getenv("SPACE_ID"):
 else:
     USE_HF_SPACE = False
 
-# for tokenizer
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
 
@@ -81,6 +71,7 @@ def get_model(name: str, max_seq_length=512):
         device = "cuda"
     elif torch.backends.mps.is_available():
         device = "mps"
+
     model = SentenceTransformer(name, device=device)
     model.max_seq_length = max_seq_length
     return model
@@ -93,9 +84,7 @@ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
 
 
 @st.cache_resource
-def get_faiss_index(
-    index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
-):
+def get_faiss_index(index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME):
     target_path = f"faiss_indexes/{name}/{index_name}"
     dm = DownloadManager()
     index_local_path = dm.download(
@@ -110,9 +99,7 @@ def text_to_emb(model, text: str, prefix: str):
     return model.encode([prefix + text], normalize_embeddings=True)
 
 
-def search(
-    faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
-):
+def search(faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int):
     start_time = time()
     emb = text_to_emb(emb_model, question, search_text_prefix)
     emb_exec_time = time() - start_time
@@ -121,7 +108,7 @@ def search(
     scores = scores[0]
     indexes = indexes[0]
     results = []
-    for idx, score in zip(indexes, scores):
+    for idx, score in zip(indexes, scores):
         idx = int(idx)
         passage = ds[idx]
         results.append((score, passage))
@@ -133,7 +120,6 @@ def to_contexts(passages):
     for passage in passages:
         title = passage["title"]
         text = passage["text"]
-        # section = passage["section"]
         contexts += f"- {title}: {text}\n"
     return contexts
 
@@ -211,15 +197,13 @@ def app():
         key="question",
         value="楽曲『約束はいらない』でデビューした、声優は誰?",
     )
-
-
-
-
-
-
-
-    else:
-        st.session_state.openai_api_key = OPENAI_API_KEY
+    st.text_input(
+        "OpenAI API Key",
+        key="openai_api_key",
+        type="password",
+        value=OPENAI_API_KEY if OPENAI_API_KEY else "",
+        placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
+    )
 
     with st.expander("オプション"):
         option_cols_main = st.columns(2)
@@ -229,6 +213,8 @@ def app():
             st.selectbox(
                 "OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
             )
+        if "emb_model_name" not in st.session_state:
+            st.session_state.emb_model_name = EMB_MODEL_NAMES[0]  # replace with the actual default value you want to use
         emb_model_name = st.session_state.emb_model_name
         option_cols_sub = st.columns(2)
         with option_cols_sub[0]:
@@ -300,10 +286,10 @@ def app():
     st.dataframe(df, hide_index=True)
 
     openai_api_key = st.session_state.openai_api_key
-
+    openai_model_name = st.session_state.openai_model_name
+    if openai_api_key and openai_model_name != "Search Only":
         openai_api_key = openai_api_key.strip()
         answer_header.subheader("Answer: ")
-        openai_model_name = st.session_state.openai_model_name
         temperature = st.session_state.temperature
         qa_prompt = st.session_state.qa_prompt
         max_tokens = st.session_state.max_tokens
@@ -320,4 +306,4 @@ def app():
 
 
 if __name__ == "__main__":
-    app()
+    app()
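As a quick reference for reviewers, the sketch below shows the gating behavior this diff introduces, in isolation from the app: answer generation runs only when an API key is present and a model other than "Search Only" is selected. generate_answer, maybe_answer, and the default temperature/max_tokens values are illustrative placeholders, not functions or defaults from app.py.

    # Minimal sketch of the search-only gating added in this PR (illustrative only).
    from __future__ import annotations

    from openai import OpenAI

    SEARCH_ONLY = "Search Only"


    def generate_answer(api_key: str, model: str, prompt: str,
                        temperature: float = 0.0, max_tokens: int = 256) -> str:
        # The client is constructed only when a key was actually provided.
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content


    def maybe_answer(api_key: str, model_name: str, prompt: str) -> str | None:
        # No key, or "Search Only" selected: skip generation entirely, so the app
        # falls back to showing just the retrieved passages.
        if not api_key.strip() or model_name == SEARCH_ONLY:
            return None
        return generate_answer(api_key.strip(), model_name, prompt)


    # With an empty key the app stays in search-only mode and never calls OpenAI.
    assert maybe_answer("", "gpt-3.5-turbo-1106", "dummy prompt") is None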
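A second standalone sketch covers the Streamlit side of the change: the optional password input pre-filled from the OPENAI_API_KEY environment variable, and the session_state default guard added for emb_model_name. EMB_MODEL_NAMES, the caption strings, and the English placeholder text here are stand-ins, not the app's actual values.

    # sketch_app.py — run with: streamlit run sketch_app.py
    import os

    import streamlit as st

    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
    EMB_MODEL_NAMES = ["model-a", "model-b"]  # stand-in for the app's real list

    # Optional key: pre-filled when the environment variable is set, empty otherwise.
    st.text_input(
        "OpenAI API Key",
        key="openai_api_key",
        type="password",
        value=OPENAI_API_KEY if OPENAI_API_KEY else "",
        placeholder="Leave empty to search without generating an answer",
    )

    # Guard the session_state read with a default, as the diff does for emb_model_name.
    if "emb_model_name" not in st.session_state:
        st.session_state.emb_model_name = EMB_MODEL_NAMES[0]

    mode = "search + answer" if st.session_state.openai_api_key else "search only"
    st.caption(f"Mode: {mode} / embedding model: {st.session_state.emb_model_name}")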