# debate-bot / query.py
# Last commit cddbc52 by ofermend: "some updates"
import requests
import json
import re
from urllib.parse import quote
def extract_between_tags(text, start_tag, end_tag):
    """Return the text between the first ``start_tag`` and the next ``end_tag`` after it.

    Args:
        text: String to search.
        start_tag: Opening delimiter.
        end_tag: Closing delimiter; searched for only after the end of ``start_tag``.

    Returns:
        The substring strictly between the two tags (tags excluded), or ``""`` when
        ``start_tag`` is absent or no ``end_tag`` follows it.
    """
    start_index = text.find(start_tag)
    if start_index == -1:
        return ""
    content_start = start_index + len(start_tag)
    # Search after the opening tag so identical start/end tags don't match the opener itself.
    end_index = text.find(end_tag, content_start)
    if end_index == -1:
        return ""
    # end_index already points at the first character of end_tag; no length adjustment needed.
    return text[content_start:end_index]
class VectaraQuery:
    """Client for Vectara's v1 ``stream-query`` endpoint, specialised for a debate bot.

    Each call to :meth:`submit_query` sends one debate turn and streams the generated
    summary back. The chat ``conversationId`` reported by the service is stored in
    ``self.conv_id`` and sent with the next request, so successive turns share a
    Vectara chat conversation.
    """

    # Inline citation markers such as "[3]" that the summarizer inserts into its text.
    _CITATION_RE = re.compile(r"\[\d+\]")
    # Whitespace left dangling before a period once a citation is removed ("foo [1]." -> "foo .").
    _SPACE_BEFORE_PERIOD_RE = re.compile(r"\s+\.")
    # User-facing message for any failure other than the chat-turn limit.
    _GENERIC_ERROR = "Sorry, something went wrong in my brain. Please try again later."

    def __init__(self, api_key: str, customer_id: str, corpus_id: str, prompt_name: str = None):
        """Store credentials, corpus coordinates and the summarizer prompt name.

        Args:
            api_key: Vectara API key, sent as the ``x-api-key`` header.
            customer_id: Vectara customer id, sent as the ``customer-id`` header and in the corpus key.
            corpus_id: Id of the corpus to search.
            prompt_name: Summarizer prompt name; falls back to
                ``"vectara-experimental-summary-ext-2023-12-11-large"`` when falsy.
        """
        self.customer_id = customer_id
        self.corpus_id = corpus_id
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-large"
        # Chat conversation id; None starts a new conversation on the next request.
        self.conv_id = None

    @staticmethod
    def _escape_prompt_value(value) -> str:
        """Escape backslashes and double quotes so *value* can sit inside a JSON string literal."""
        # Backslashes first; otherwise the backslashes added in front of quotes would be doubled.
        return str(value).replace('\\', '\\\\').replace('"', '\\"')

    @classmethod
    def _clean_citations(cls, text: str) -> str:
        """Remove "[N]" citation markers and any whitespace they leave before a period."""
        text = cls._CITATION_RE.sub("", text)
        return cls._SPACE_BEFORE_PERIOD_RE.sub(".", text)

    def get_body(self, user_response: str, role: str, topic: str, style: str):
        """Build the request body for one debate turn.

        Args:
            user_response: The opponent's latest argument, quoted inside the prompt.
            role: Stance phrase placed between "argument"/"response" and the topic.
            topic: Debate topic.
            style: Debate style the bot is told to use.

        Returns:
            A dict to be serialized with ``json.dumps`` and POSTed to the stream-query endpoint.
        """
        corpora_key_list = [{
            'customer_id': self.customer_id, 'corpus_id': self.corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
        }]
        # The prompt is a JSON list of chat messages using Velocity-style template syntax
        # (#foreach, $esc.java). Interpolated values must therefore not break out of their
        # JSON string literals.
        # NOTE(review): values are not escaped for the template engine; text containing '$' or
        # '#' directives might be interpreted by it — confirm whether this needs handling.
        esc_response = self._escape_prompt_value(user_response)
        esc_role = self._escape_prompt_value(role)
        esc_topic = self._escape_prompt_value(topic)
        esc_style = self._escape_prompt_value(style)
        prompt = f'''
[
{{
"role": "system",
"content": "You are a professional debate bot.
You specialize in the {esc_style} debate style.
You are provided with search results related to {esc_topic}.
Follow these INSTRUCTIONS carefully:
1. Provide a thoughtful and convincing reply.
2. Do not base your response on information or knowledge that is not in the search results.
3. Respond with respect to your opponent.
4. Limit your responses to not more than 2 paragraphs."
}},
{{
"role": "assistant",
"content": "
#foreach ($qResult in $vectaraQueryResults)
Search result $esc.java(${{foreach.index}}+1): $esc.java(${{qResult.getText()}})
#end
"
}},
{{
"role": "user",
"content": "Provide a convincing response {esc_role} {esc_topic}, to the question '$esc.java(${{vectaraQuery}})'.
Consider the search results as relevant information with which to form your response, but do not mention the results in your response.
Consider the last argument from your opponent: '{esc_response}'.
Use the {esc_style} debate style to make your argument."
}}
]
'''
        return {
            'query': [
                {
                    # Plain JSON field: json.dumps does the escaping here, so use the raw values.
                    'query': f"What is a good argument {role} {topic}",
                    'start': 0,
                    'numResults': 50,
                    'corpusKey': corpora_key_list,
                    'context_config': {
                        'sentences_before': 2,
                        'sentences_after': 2,
                        'start_tag': "%START_SNIPPET%",
                        'end_tag': "%END_SNIPPET%",
                    },
                    'rerankingConfig':
                    {
                        'rerankerId': 272725718,
                        'mmrConfig': {
                            'diversityBias': 0.3
                        }
                    },
                    'summary': [
                        {
                            'responseLang': 'eng',
                            'maxSummarizedResults': 7,
                            'summarizerPromptName': self.prompt_name,
                            'promptText': prompt,
                            'chat': {
                                'store': True,
                                'conversationId': self.conv_id
                            },
                        }
                    ]
                }
            ]
        }

    def get_headers(self):
        """Return the HTTP headers for a Vectara API request."""
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "customer-id": self.customer_id,
            "x-api-key": self.api_key,
            "grpc-timeout": "60S"
        }

    def submit_query(self, query_str: str, bot_role: str, topic: str, style: str):
        """Stream the bot's reply to the opponent's argument *query_str*.

        This is a generator. It yields the reply in pieces as the service streams it,
        with "[N]" citation markers removed. On failure it yields a single user-facing
        error message instead of reply text.

        Args:
            query_str: The opponent's latest argument.
            bot_role: Stance phrase for the bot (see ``get_body``'s ``role``).
            topic: Debate topic.
            style: Debate style.

        Yields:
            Reply text pieces, or one error message.

        Returns:
            The full yielded text, delivered as the generator's return value
            (``StopIteration.value``), which is visible only to callers using ``yield from``.
        """
        endpoint = "https://api.vectara.io/v1/stream-query"
        body = self.get_body(query_str, bot_role, topic, style)
        # Hold back this many trailing characters between yields so a citation marker split
        # across two streamed chunks is removed before any of it is emitted.
        pattern_max_length = 50
        chunks = []
        accumulated_text = ""
        # Context manager releases the streaming connection on early return or break.
        with requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True) as response:
            if response.status_code != 200:
                print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
                # A generator's return value is invisible to normal iteration, so yield the message.
                yield self._GENERIC_ERROR
                return self._GENERIC_ERROR
            for line in response.iter_lines():
                if not line:  # skip keep-alive new lines
                    continue
                data = json.loads(line.decode('utf-8'))
                res = data.get('result') or {}
                # Lines carrying search results are skipped; only summary chunks are streamed out.
                if res.get('responseSet') is not None:
                    continue
                summary = res.get('summary')
                if not summary:
                    continue
                chat = summary.get('chat')
                if chat and chat.get('status'):
                    # NOTE(review): assumes 'status' is a bare code string; if the API returns a
                    # status object, the RESOURCE_EXHAUSTED comparison never matches — confirm.
                    st_code = chat['status']
                    print(f"Chat query failed with code {st_code}")
                    if st_code == 'RESOURCE_EXHAUSTED':
                        # Drop the conversation so the next turn starts a fresh one.
                        self.conv_id = None
                        message = 'Sorry, Vectara chat turns exceeds plan limit.'
                    else:
                        message = self._GENERIC_ERROR
                    yield message
                    return message
                conv_id = chat.get('conversationId') if chat else None
                if conv_id:
                    self.conv_id = conv_id
                accumulated_text += summary.get('text') or ''
                if len(accumulated_text) > pattern_max_length:
                    accumulated_text = self._clean_citations(accumulated_text)
                    out_chunk = accumulated_text[:-pattern_max_length]
                    # Cleaning can shrink the buffer below the hold-back size; don't emit empty pieces.
                    if out_chunk:
                        chunks.append(out_chunk)
                        yield out_chunk
                    accumulated_text = accumulated_text[-pattern_max_length:]
                if summary.get('done'):
                    break
        # Flush the held-back tail with the same clean-up as the streamed pieces.
        tail = self._clean_citations(accumulated_text)
        if tail:
            chunks.append(tail)
            yield tail
        return ''.join(chunks)