david-oplatka committed
Commit 24de7c1
1 Parent(s): 23a883f

Add Template Files

Files changed (6)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. Vectara-logo.png +0 -0
  4. app.py +93 -0
  5. query.py +198 -0
  6. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: "Enter Chatbot Title"
+ emoji: 📈
+ colorFrom: indigo
+ colorTo: green
+ sdk: streamlit
+ sdk_version: 1.32.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: "Enter Description"
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Vectara-logo.png ADDED
app.py ADDED
@@ -0,0 +1,93 @@
+ from omegaconf import OmegaConf
+ from query import VectaraQuery
+ import os
+
+ import streamlit as st
+ from PIL import Image
+ from dotenv import load_dotenv
+
+
+ load_dotenv(override=False)
+
+ def isTrue(x) -> bool:
+     if isinstance(x, bool):
+         return x
+     return x.strip().lower() == 'true'
+
+ def launch_bot():
+     def generate_response(question):
+         response = vq.submit_query(question)
+         return response
+
+     def generate_streaming_response(question):
+         response = vq.submit_query_streaming(question)
+         return response
+
+     if 'cfg' not in st.session_state:
+         corpus_ids = str(os.environ['corpus_ids']).split(',')
+         cfg = OmegaConf.create({
+             'customer_id': str(os.environ['customer_id']),
+             'corpus_ids': corpus_ids,
+             'api_key': str(os.environ['api_key']),
+             'title': os.environ['title'],
+             'description': os.environ['description'],
+             'source_data_desc': os.environ['source_data_desc'],
+             'streaming': isTrue(os.environ.get('streaming', False)),
+             'prompt_name': os.environ.get('prompt_name', None)
+         })
+         st.session_state.cfg = cfg
+         st.session_state.vq = VectaraQuery(cfg.api_key, cfg.customer_id, cfg.corpus_ids, cfg.prompt_name)
+
+     cfg = st.session_state.cfg
+     vq = st.session_state.vq
+     st.set_page_config(page_title=cfg.title, layout="wide")
+
+     # left side content
+     with st.sidebar:
+         image = Image.open('Vectara-logo.png')
+         st.markdown(f"## Welcome to {cfg.title}\n\n"
+                     f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n\n")
+
+         st.markdown("---")
+         st.markdown(
+             "## How does this work?\n"
+             "This app was built with [Vectara](https://vectara.com).\n"
+             "Vectara's [Indexing API](https://docs.vectara.com/docs/api-reference/indexing-apis/indexing) was used to ingest the data into a Vectara corpus (or index).\n\n"
+             "This app uses Vectara's [Chat API](https://docs.vectara.com/docs/console-ui/vectara-chat-overview) to query the corpus and present the results to you, answering your question.\n\n"
+         )
+         st.markdown("---")
+         st.image(image, width=250)
+
+     st.markdown(f"<center> <h2> Vectara chat demo: {cfg.title} </h2> </center>", unsafe_allow_html=True)
+     st.markdown(f"<center> <h4> {cfg.description} </h4> </center>", unsafe_allow_html=True)
+
+     if "messages" not in st.session_state.keys():
+         st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
+
+     # Display chat messages
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.write(message["content"])
+
+     # User-provided prompt
+     if prompt := st.chat_input():
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.write(prompt)
+
+     # Generate a new response if last message is not from assistant
+     if st.session_state.messages[-1]["role"] != "assistant":
+         with st.chat_message("assistant"):
+             if cfg.streaming:
+                 stream = generate_streaming_response(prompt)
+                 response = st.write_stream(stream)
+             else:
+                 with st.spinner("Thinking..."):
+                     response = generate_response(prompt)
+                     st.write(response)
+             message = {"role": "assistant", "content": response}
+             st.session_state.messages.append(message)
+
+ if __name__ == "__main__":
+     launch_bot()
+
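Note: app.py reads all of its settings from environment variables (loaded via python-dotenv), which on a Hugging Face Space come from the Space's secrets and variables. As a rough local-run sketch, a helper like the one below (every value is a placeholder, not a real credential, and the script itself is not part of this commit) writes a matching .env file so `streamlit run app.py` can pick the settings up:

# make_env.py -- hypothetical helper, not part of this commit
from pathlib import Path

env = {
    'customer_id': '1234567890',            # placeholder Vectara customer id
    'corpus_ids': '1,2',                    # comma-separated; app.py splits on ','
    'api_key': 'zqt-...',                   # placeholder query API key
    'title': 'My Chatbot',
    'description': 'Ask questions about my data',
    'source_data_desc': 'the product documentation',
    'streaming': 'True',                    # parsed by isTrue(); use 'False' to disable
    # 'prompt_name' is optional; VectaraQuery falls back to a default prompt
}
Path('.env').write_text(''.join(f'{k}={v}\n' for k, v in env.items()))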
query.py ADDED
@@ -0,0 +1,198 @@
+ import requests
+ import json
+ import re
+ from urllib.parse import quote
+
+ def extract_between_tags(text, start_tag, end_tag):
+     start_index = text.find(start_tag)
+     end_index = text.find(end_tag, start_index)
+     # end_index already points at the start of end_tag, so slice up to it
+     return text[start_index+len(start_tag):end_index]
+
+ class CitationNormalizer:
+
+     def __init__(self, responses, docs):
+         self.docs = docs
+         self.responses = responses
+         self.refs = []
+
+     def normalize_citations(self, summary):
+         start_tag = "%START_SNIPPET%"
+         end_tag = "%END_SNIPPET%"
+
+         # find all references in the summary
+         pattern = r'\[\d{1,2}\]'
+         matches = [match.span() for match in re.finditer(pattern, summary)]
+
+         # figure out unique list of references
+         for match in matches:
+             start, end = match
+             response_num = int(summary[start+1:end-1])
+             doc_num = self.responses[response_num-1]['documentIndex']
+             metadata = {item['name']: item['value'] for item in self.docs[doc_num]['metadata']}
+             text = extract_between_tags(self.responses[response_num-1]['text'], start_tag, end_tag)
+             if 'url' in metadata.keys():
+                 url = f"{metadata['url']}#:~:text={quote(text)}"
+                 if url not in self.refs:
+                     self.refs.append(url)
+
+         # replace references with markdown links
+         refs_dict = {url: (inx+1) for inx, url in enumerate(self.refs)}
+         for match in reversed(matches):
+             start, end = match
+             response_num = int(summary[start+1:end-1])
+             doc_num = self.responses[response_num-1]['documentIndex']
+             metadata = {item['name']: item['value'] for item in self.docs[doc_num]['metadata']}
+             text = extract_between_tags(self.responses[response_num-1]['text'], start_tag, end_tag)
+             if 'url' in metadata.keys():
+                 url = f"{metadata['url']}#:~:text={quote(text)}"
+                 citation_inx = refs_dict[url]
+                 summary = summary[:start] + f'[\[{citation_inx}\]]({url})' + summary[end:]
+             else:
+                 summary = summary[:start] + summary[end:]
+
+         return summary
+
+ class VectaraQuery:
+     def __init__(self, api_key: str, customer_id: str, corpus_ids: list[str], prompt_name: str = None):
+         self.customer_id = customer_id
+         self.corpus_ids = corpus_ids
+         self.api_key = api_key
+         self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-sml"
+         self.conv_id = None
+
+     def get_body(self, query_str: str):
+         corpora_key_list = [{
+             'customer_id': self.customer_id, 'corpus_id': corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
+         } for corpus_id in self.corpus_ids]
+
+         return {
+             'query': [
+                 {
+                     'query': query_str,
+                     'start': 0,
+                     'numResults': 50,
+                     'corpusKey': corpora_key_list,
+                     'context_config': {
+                         'sentences_before': 2,
+                         'sentences_after': 2,
+                         'start_tag': "%START_SNIPPET%",
+                         'end_tag': "%END_SNIPPET%",
+                     },
+                     'rerankingConfig': {
+                         'rerankerId': 272725718,
+                         'mmrConfig': {
+                             'diversityBias': 0.3
+                         }
+                     },
+                     'summary': [
+                         {
+                             'responseLang': 'eng',
+                             'maxSummarizedResults': 5,
+                             'summarizerPromptName': self.prompt_name,
+                             'chat': {
+                                 'store': True,
+                                 'conversationId': self.conv_id
+                             },
+                         }
+                     ]
+                 }
+             ]
+         }
+
+     def get_headers(self):
+         return {
+             "Content-Type": "application/json",
+             "Accept": "application/json",
+             "customer-id": self.customer_id,
+             "x-api-key": self.api_key,
+             "grpc-timeout": "60S"
+         }
+
+     def submit_query(self, query_str: str):
+         endpoint = "https://api.vectara.io/v1/query"
+         body = self.get_body(query_str)
+
+         response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
+         if response.status_code != 200:
+             print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
+             return "Sorry, something went wrong in my brain. Please try again later."
+
+         res = response.json()
+
+         top_k = 10
+         summary = res['responseSet'][0]['summary'][0]['text']
+         responses = res['responseSet'][0]['response'][:top_k]
+         docs = res['responseSet'][0]['document']
+         chat = res['responseSet'][0]['summary'][0].get('chat', None)
+
+         if chat and chat['status'] is not None:
+             st_code = chat['status']
+             print(f"Chat query failed with code {st_code}")
+             if st_code == 'RESOURCE_EXHAUSTED':
+                 self.conv_id = None
+                 return 'Sorry, Vectara chat turns exceed the plan limit.'
+             return 'Sorry, something went wrong in my brain. Please try again later.'
+
+         self.conv_id = chat['conversationId'] if chat else None
+         summary = CitationNormalizer(responses, docs).normalize_citations(summary)
+         return summary
+
+     def submit_query_streaming(self, query_str: str):
+         endpoint = "https://api.vectara.io/v1/stream-query"
+         body = self.get_body(query_str)
+
+         response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True)
+         if response.status_code != 200:
+             print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
+             # this method is a generator: yield the error so the caller can display it
+             yield "Sorry, something went wrong in my brain. Please try again later."
+             return
+
+         chunks = []
+         accumulated_text = ""  # Initialize text accumulation
+         pattern_max_length = 50  # Example heuristic
+         for line in response.iter_lines():
+             if line:  # filter out keep-alive new lines
+                 data = json.loads(line.decode('utf-8'))
+                 res = data['result']
+                 response_set = res['responseSet']
+                 if response_set is None:
+                     # grab next chunk and yield it as output
+                     summary = res.get('summary', None)
+                     if summary is None or len(summary) == 0:
+                         continue
+                     else:
+                         chat = summary.get('chat', None)
+                         if chat and chat.get('status', None):
+                             st_code = chat['status']
+                             print(f"Chat query failed with code {st_code}")
+                             if st_code == 'RESOURCE_EXHAUSTED':
+                                 self.conv_id = None
+                                 yield 'Sorry, Vectara chat turns exceed the plan limit.'
+                                 return
+                             yield 'Sorry, something went wrong in my brain. Please try again later.'
+                             return
+                         conv_id = chat.get('conversationId', None) if chat else None
+                         if conv_id:
+                             self.conv_id = conv_id
+
+                     chunk = summary['text']
+                     accumulated_text += chunk  # Append current chunk to accumulation
+                     if len(accumulated_text) > pattern_max_length:
+                         # strip citation markers and tidy spacing before emitting
+                         accumulated_text = re.sub(r"\[\d+\]", "", accumulated_text)
+                         accumulated_text = re.sub(r"\s+\.", ".", accumulated_text)
+                         out_chunk = accumulated_text[:-pattern_max_length]
+                         chunks.append(out_chunk)
+                         yield out_chunk
+                         accumulated_text = accumulated_text[-pattern_max_length:]
+
+                     if summary['done']:
+                         break
+
+         # yield the last piece
+         if len(accumulated_text) > 0:
+             accumulated_text = re.sub(r" \[\d+\]\.", ".", accumulated_text)
+             chunks.append(accumulated_text)
+             yield accumulated_text
+
+         return ''.join(chunks)
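Note: VectaraQuery can also be exercised outside Streamlit. A minimal usage sketch, with placeholder credentials throughout (nothing below ships with this commit): submit_query() returns a finished, citation-normalized summary, while submit_query_streaming() is a generator that yields text chunks as the summary is produced.

# try_query.py -- hypothetical usage sketch, not part of this commit
from query import VectaraQuery

vq = VectaraQuery(api_key='zqt-...',           # placeholder API key
                  customer_id='1234567890',    # placeholder customer id
                  corpus_ids=['1'])            # one or more corpus ids

# blocking call: full summary with markdown citation links
print(vq.submit_query('What topics does this corpus cover?'))

# streaming call: print chunks as they arrive
for chunk in vq.submit_query_streaming('Summarize the main topics.'):
    print(chunk, end='', flush=True)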
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ requests_to_curl==1.1.0
+ toml==0.10.2
+ omegaconf==2.3.0
+ syrupy==4.0.8