initial
app.py
ADDED
@@ -0,0 +1,71 @@
from omegaconf import OmegaConf
from query import VectaraQuery
import streamlit as st
import os

def isTrue(x) -> bool:
    # Interpret booleans passed as strings (e.g. "True"/"true") as real booleans.
    if isinstance(x, bool):
        return x
    return x.strip().lower() == 'true'

def launch_bot():
    def generate_response(question, role, topic):
        response = vq.submit_query(question, role, topic)
        return response

    if 'cfg' not in st.session_state:
        cfg = OmegaConf.create({
            'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
            'corpus_id': str(os.environ['VECTARA_CORPUS_ID']),
            'api_key': str(os.environ['VECTARA_API_KEY']),
            'prompt_name': 'vectara-experimental-summary-ext-2023-12-11-large',
            'topic': 'standardized testing in education',
            'human_role': 'in opposition to',
            'bot_role': 'in support of'
        })
        st.session_state.cfg = cfg
        st.session_state.vq = VectaraQuery(cfg.api_key, cfg.customer_id, cfg.corpus_id, cfg.prompt_name)

    cfg = st.session_state.cfg
    vq = st.session_state.vq
    st.set_page_config(page_title="Debate Bot", layout="wide")

    # left side content
    with st.sidebar:
        st.markdown(f"## Welcome to Debate Bot.\n\n\n"
                    f"You are {cfg.human_role} '{cfg.topic}'.\n\n")

        st.markdown("---")
        st.markdown(
            "## How does this work?\n"
            "This app was built with [Vectara](https://vectara.com).\n"
        )
        st.markdown("---")

    if "messages" not in st.session_state.keys():
        st.session_state.messages = [{"role": "assistant", "content": f"Please make your first statement {cfg.human_role} '{cfg.topic}'"}]

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a new response if last message is not from assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            stream = generate_response(prompt, cfg.bot_role, cfg.topic)
            response = st.write_stream(stream)
            message = {"role": "assistant", "content": response}
            st.session_state.messages.append(message)

if __name__ == "__main__":
    launch_bot()
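app.py reads three Vectara credentials from the environment at startup and raises a KeyError if any is missing. As a quick pre-flight check before `streamlit run app.py`, something like the following hypothetical helper (not part of this commit) can be used:

import os

# Hypothetical pre-flight check: verify the variables app.py expects are set.
for var in ("VECTARA_CUSTOMER_ID", "VECTARA_CORPUS_ID", "VECTARA_API_KEY"):
    if not os.environ.get(var):
        raise SystemExit(f"Missing required environment variable: {var}")
print("Environment OK; start the bot with: streamlit run app.py")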
query.py
ADDED
@@ -0,0 +1,151 @@
import requests
import json
import re
from urllib.parse import quote

def extract_between_tags(text, start_tag, end_tag):
    # Return the substring between start_tag and end_tag (tags excluded).
    start_index = text.find(start_tag)
    end_index = text.find(end_tag, start_index)
    return text[start_index+len(start_tag):end_index]

class VectaraQuery():
    def __init__(self, api_key: str, customer_id: str, corpus_id: str, prompt_name: str = None):
        self.customer_id = customer_id
        self.corpus_id = corpus_id
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-large"
        self.conv_id = None

    def get_body(self, user_response: str, role: str, topic: str):
        corpora_key_list = [{
            'customer_id': self.customer_id, 'corpus_id': self.corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
        }]

        prompt = f'''
        [
          {{
            "role": "system",
            "content": "You are a professional debate bot. You are provided with search results related to {topic}
            and respond to the previous arguments made so far. Be sure to provide a thoughtful and convincing reply.
            Never mention search results explicitly in your response.
            Do not base your response on information or knowledge that is not in the search results.
            Respond while demonstrating respect to the other party and the topic. Limit your responses to no more than 3 paragraphs."
          }},
          {{
            "role": "user",
            "content": "
            #foreach ($qResult in $vectaraQueryResults)
              Search result $esc.java(${{foreach.index}}+1): $esc.java(${{qResult.getText()}})
            #end
            "
          }},
          {{
            "role": "user",
            "content": "Provide a convincing reply {role} {topic}.
            Consider the search results as relevant information with which to form your response.
            Do not repeat earlier arguments and make sure your new response is coherent with the previous arguments, and responsive to the last argument: {user_response}."
          }}
        ]
        '''

        return {
            'query': [
                {
                    'query': "how would you respond?",
                    'start': 0,
                    'numResults': 50,
                    'corpusKey': corpora_key_list,
                    'context_config': {
                        'sentences_before': 2,
                        'sentences_after': 2,
                        'start_tag': "%START_SNIPPET%",
                        'end_tag': "%END_SNIPPET%",
                    },
                    'rerankingConfig': {
                        'rerankerId': 272725718,
                        'mmrConfig': {
                            'diversityBias': 0.3
                        }
                    },
                    'summary': [
                        {
                            'responseLang': 'eng',
                            'maxSummarizedResults': 7,
                            'summarizerPromptName': self.prompt_name,
                            'promptText': prompt,
                            'chat': {
                                'store': True,
                                'conversationId': self.conv_id
                            },
                        }
                    ]
                }
            ]
        }

    def get_headers(self):
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "customer-id": self.customer_id,
            "x-api-key": self.api_key,
            "grpc-timeout": "60S"
        }

    def submit_query(self, query_str: str, role: str, topic: str):
        endpoint = "https://api.vectara.io/v1/stream-query"
        body = self.get_body(query_str, role, topic)

        response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True)
        if response.status_code != 200:
            print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
            return "Sorry, something went wrong in my brain. Please try again later."

        chunks = []
        accumulated_text = ""    # accumulate text so citation markers split across chunks can be stripped
        pattern_max_length = 50  # heuristic: longest citation pattern expected to straddle a chunk boundary
        for line in response.iter_lines():
            if line:  # filter out keep-alive new lines
                data = json.loads(line.decode('utf-8'))
                res = data['result']
                response_set = res['responseSet']
                if response_set is None:
                    # grab next chunk and yield it as output
                    summary = res.get('summary', None)
                    if summary is None or len(summary) == 0:
                        continue
                    else:
                        chat = summary.get('chat', None)
                        if chat and chat.get('status', None):
                            st_code = chat['status']
                            print(f"Chat query failed with code {st_code}")
                            if st_code == 'RESOURCE_EXHAUSTED':
                                self.conv_id = None
                                return 'Sorry, Vectara chat turns exceed the plan limit.'
                            return 'Sorry, something went wrong in my brain. Please try again later.'
                        conv_id = chat.get('conversationId', None) if chat else None
                        if conv_id:
                            self.conv_id = conv_id

                    chunk = summary['text']
                    accumulated_text += chunk  # append current chunk to accumulation
                    if len(accumulated_text) > pattern_max_length:
                        accumulated_text = re.sub(r"\[\d+\]", "", accumulated_text)
                        accumulated_text = re.sub(r"\s+\.", ".", accumulated_text)
                        out_chunk = accumulated_text[:-pattern_max_length]
                        chunks.append(out_chunk)
                        yield out_chunk
                        accumulated_text = accumulated_text[-pattern_max_length:]

                    if summary['done']:
                        break

        # yield the last piece
        if len(accumulated_text) > 0:
            accumulated_text = re.sub(r" \[\d+\]\.", ".", accumulated_text)
            chunks.append(accumulated_text)
            yield accumulated_text

        return ''.join(chunks)
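For reference, VectaraQuery can also be exercised outside Streamlit, since submit_query is a plain generator. A minimal standalone sketch (the question, role, and topic below are placeholders, assuming the same environment variables as app.py):

import os
from query import VectaraQuery

# Hypothetical smoke test: stream one debate reply to stdout.
vq = VectaraQuery(
    api_key=os.environ["VECTARA_API_KEY"],
    customer_id=os.environ["VECTARA_CUSTOMER_ID"],
    corpus_id=os.environ["VECTARA_CORPUS_ID"],
)
for chunk in vq.submit_query("Make an opening argument.",
                             "in support of",
                             "standardized testing in education"):
    print(chunk, end="", flush=True)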