# Streamlit chatbot agents: Query_Agent, Answering_Agent, Concepts_Agent, Head_Agent.
from setup_code import * # This imports everything from setup_code.py | |
class Query_Agent:
    """Embed a user query, search the matching Pinecone index and build a reply.

    Two topics are supported:

    * ``'ml'``     - OpenAI embedding, searched in ``pinecone_index`` (or the
      ``index`` argument when one is given).
    * ``'python'`` - CodeBERT embedding, always searched in
      ``pinecone_index_python``.

    The helpers ``get_top_k_text``, ``get_top_filename``,
    ``contains_py_filename`` and ``get_completion`` are not defined in this
    class; presumably they come from the ``setup_code`` star import.
    """

    # Topics for which an embedding strategy and a result extractor exist.
    _SUPPORTED_TOPICS = ("ml", "python")

    def __init__(self, pinecone_index, pinecone_index_python, openai_client) -> None:
        """Store the two indexes and the OpenAI client, and load CodeBERT."""
        self.pinecone_index = pinecone_index
        self.pinecone_index_python = pinecone_index_python
        self.openai_client = openai_client
        # Embedding of the most recent query; set by query_vector_store().
        self.query_embedding = None
        # NOTE: the attribute keeps its historical spelling ("codbert") so
        # that any outside code reading it keeps working.
        self.codbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")

    def get_codebert_embedding(self, code: str):
        """Return a CodeBERT embedding of ``code`` as a flat Python list.

        The input is truncated to 512 tokens and the model's last hidden
        state is averaged over dimension 1 (a simple pooling strategy).
        """
        inputs = self.codbert_tokenizer(code, return_tensors="pt", max_length=512, truncation=True)
        outputs = self.codebert_model(**inputs)
        pooled = outputs.last_hidden_state.mean(dim=1)
        # Only one input was encoded, so keep the first (only) row.
        return pooled.detach().numpy().tolist()[0]

    def get_openai_embedding(self, text, model="text-embedding-ada-002"):
        """Return the OpenAI embedding of ``text`` with newlines replaced by spaces."""
        text = text.replace("\n", " ")
        return self.openai_client.embeddings.create(input=[text], model=model).data[0].embedding

    @staticmethod
    def _first_namespace(index):
        """Return the first namespace name reported by ``index.describe_index_stats()``.

        Raises:
            IndexError: if the index reports no namespaces.
        """
        namespaces = list(index.describe_index_stats()['namespaces'].keys())
        if not namespaces:
            raise IndexError("Pinecone index reports no namespaces to query")
        return namespaces[0]

    def query_vector_store(self, query, query_topic: str, index=None, k=5) -> str:
        """Embed ``query`` and return the top-``k`` matches for ``query_topic``.

        Args:
            query: Text to embed and search for.
            query_topic: ``'ml'`` or ``'python'``.
            index: Index used for ``'ml'`` queries; defaults to
                ``self.pinecone_index``. Ignored for ``'python'``, which
                always uses ``self.pinecone_index_python``.
            k: Number of matches requested from the index.

        Returns:
            Whatever ``get_top_k_text`` (ml) or ``get_top_filename`` (python)
            extracts from the query result.

        Raises:
            ValueError: if ``query_topic`` is not a supported topic.

        Side effect: ``self.query_embedding`` is set to the new embedding.
        """
        # Validate first: previously an unknown topic fell through to an
        # unbound ``matches_text`` after already hitting the network.
        if query_topic not in self._SUPPORTED_TOPICS:
            raise ValueError(
                f"Unsupported query_topic {query_topic!r}; expected one of {self._SUPPORTED_TOPICS}"
            )

        if query_topic == 'python':
            index = self.pinecone_index_python
            self.query_embedding = self.get_codebert_embedding(query)
            extract_matches = get_top_filename
        else:
            if index is None:
                index = self.pinecone_index
            self.query_embedding = self.get_openai_embedding(query)
            extract_matches = get_top_k_text

        query_result = index.query(
            namespace=self._first_namespace(index),
            top_k=k,
            vector=self.query_embedding,
            include_values=True,
            include_metadata=True,
        )
        return extract_matches(query_result)

    def process_query_response(self, head_agent, user_query, query_topic):
        """Answer ``user_query`` using retrieved documents when they are relevant.

        The recent history for ``query_topic`` is appended to the query
        before retrieval. If the relevance agent accepts the matches, or the
        matches contain a Python filename, the answering agent summarises
        them. Otherwise the model is asked directly and the reply is
        prefixed with ``"[EXTERNAL] "``.
        """
        # Retrieve the history and the index related to the query_topic.
        conversation = []
        index = head_agent.pinecone_index
        if query_topic == "ml":
            conversation = Head_Agent.get_history_about('ml')
        elif query_topic == 'python':
            conversation = Head_Agent.get_history_about('python')
            index = head_agent.pinecone_index_python

        user_query_plus_conversation = f"The current query is: {user_query}"
        if len(conversation) > 0:
            conversation_text = "\n".join(conversation)
            user_query_plus_conversation += f'The current conversation is: {conversation_text}'

        # self.query_embedding is set inside query_vector_store().
        matches_text = self.query_vector_store(user_query_plus_conversation, query_topic, index)

        if head_agent.relevant_documents_agent.is_relevant(matches_text, user_query_plus_conversation) or contains_py_filename(matches_text):
            return head_agent.answering_agent.generate_response(
                user_query, matches_text, conversation, head_agent.selected_mode
            )

        prompt_for_gpt = f"Return a response to this query: {user_query} in the context of this conversation: {conversation}. Please use language appropriate for a {head_agent.selected_mode}."
        response = get_completion(head_agent.openai_client, prompt_for_gpt)
        return "[EXTERNAL] " + response
class Answering_Agent:
    """Compose prompts for the chat model and hand back its completions.

    ``get_completion`` and ``DEBUG`` are resolved at call time from module
    scope; they are not defined in this class.
    """

    def __init__(self, openai_client) -> None:
        """Remember the OpenAI client used for every completion."""
        self.client = openai_client

    def generate_response(self, query, docs, conv_history, selected_mode):
        """Summarise ``docs`` as an answer to ``query`` for the selected audience."""
        return get_completion(
            self.client,
            f"Based on this text in angle brackets: <{docs}>, please summarize a response to this query: {query} in the context of this conversation: {conv_history}. Please use language appropriate for a {selected_mode}.",
        )

    def generate_response_topic(self, topic_desc, topic_text, conv_history, selected_mode):
        """Summarise ``topic_text`` on the subject ``topic_desc`` for the selected audience."""
        return get_completion(
            self.client,
            f"Please return a summary response on this topic: {topic_desc} using this text as best as possible {topic_text} in the context of this {conv_history}. Please use language appropriate for a {selected_mode}.",
        )

    def generate_image(self, text):
        """Create an illustrative image for ``text`` together with a caption.

        Returns ``(None, "")`` immediately when the module-level ``DEBUG``
        flag is truthy. Otherwise returns ``(image, caption)``; the DALL-E
        prompt obtained from the model is also appended to
        ``dall_e_prompts.txt``.
        """
        if DEBUG:
            return None, ""

        # Step 1: ask the chat model for a DALL-E prompt.
        prompt_request = f"Based on this text, repeated here in double square brackets for your reference: [[{text}]], please generate a simple caption that I can use with dall-e to generate an instructional image."
        dalle_prompt = get_completion(self.client, prompt_request)

        # Step 2: keep a running log of the prompts sent to DALL-E.
        with open("dall_e_prompts.txt", "a") as prompt_log:
            prompt_log.write(f"{dalle_prompt}\n\n")

        # Step 3: render the image.
        picture = Head_Agent.text_to_image(self.client, dalle_prompt)

        # Step 4: ask for a caption to show next to the image.
        caption_request = f"This text in double square brackets is used to prompt dall-e: [[{dalle_prompt}]]. Please generate a simple caption that I can use to display with the image dall-e will create. Only return that caption."
        display_caption = get_completion(self.client, caption_request)

        return (picture, display_caption)
class Concepts_Agent:
    """Track which concept topics have been covered and render that progress.

    Concepts are read from a CSV file into ``self._df``. This class reads the
    ``concept`` and ``embedding`` columns. The progress matrix itself lives in
    ``st.session_state.topic_matrix`` (rows are concepts, columns are topics).
    """

    _DEFAULT_CSV = "/content/gdrive/MyDrive/LLM_Winter2024/concepts_final.csv"
    _NUM_TOPICS = 5
    _NUM_CONCEPTS = 12

    def __init__(self, csv_path=_DEFAULT_CSV):
        """Load the concepts table from ``csv_path`` (defaults to the original location)."""
        self._df = pd.read_csv(csv_path)

    # ----- helpers shared by the display methods -------------------------

    def _column_headers(self):
        """Return the column labels "Topic 1" .. "Topic 5"."""
        return [f"Topic {i}" for i in range(1, self._NUM_TOPICS + 1)]

    def _row_labels(self):
        """Return the names of the first twelve concepts as row labels."""
        return [f"{self._df['concept'][i]}" for i in range(self._NUM_CONCEPTS)]

    def _write_total(self):
        """Write the sum of every cell in the progress matrix."""
        st.write(f"Total Topics covered: {sum(sum(row) for row in st.session_state.topic_matrix)}")

    # ----- public API -----------------------------------------------------

    def increase_cell(self, i, j):
        """Increment cell (``i``, ``j``) of the progress matrix by one."""
        st.session_state.topic_matrix[i][j] += 1

    def display_topic_matrix(self):
        """Render the progress matrix as a table followed by the total."""
        # This method used to be defined twice with identical bodies; only
        # one definition is kept.
        topic_df = pd.DataFrame(st.session_state.topic_matrix, self._row_labels(), self._column_headers())
        st.table(topic_df)
        self._write_total()

    def display_topic_matrix_star(self):
        """Render the progress matrix with every ``1`` shown as a star."""
        # chr(9733) is the Unicode BLACK STAR character.
        topic_matrix_star = [
            [chr(9733) if val == 1 else val for val in row]
            for row in st.session_state.topic_matrix
        ]
        topic_df = pd.DataFrame(topic_matrix_star, self._row_labels(), self._column_headers())
        st.table(topic_df)
        self._write_total()

    def display_topic_matrix_as_image(self):
        """Draw the matrix's HTML markup onto a PNG, show it and return the bytes.

        Returns:
            A ``BytesIO`` positioned at offset 0 holding the PNG data.
        """
        topic_df = pd.DataFrame(st.session_state.topic_matrix, self._row_labels(), self._column_headers())
        df_html = topic_df.to_html(index=False)

        # NOTE(review): this draws the raw HTML source as text; it does not
        # render the table layout.
        image = Image.new("RGB", (800, 600), color="white")
        draw = ImageDraw.Draw(image)
        draw.text((10, 10), df_html, fill="black")

        image_byte_array = io.BytesIO()
        image.save(image_byte_array, format="PNG")
        image_byte_array.seek(0)

        st.image(image_byte_array, caption="DataFrame as Image")
        return image_byte_array

    def find_top_concept_index(self, query_embedding):
        """Return the index of the concept most similar to ``query_embedding``.

        Each row's ``embedding`` column is parsed with ``ast.literal_eval``
        and compared by cosine similarity. If no row has a similarity above
        0, index 0 is returned.
        """
        # The query vector does not change between rows, so reshape it once.
        qe_array = np.array(query_embedding).reshape(1, -1)
        top_sim = 0
        top_concept_index = 0
        for index, row in self._df.iterrows():
            concept_array = np.array(ast.literal_eval(row['embedding'])).reshape(1, -1)
            sim = cosine_similarity(concept_array, qe_array)[0][0]
            if sim > top_sim:
                top_sim = sim
                top_concept_index = index
        return top_concept_index

    def get_top_k_text_list(self, matches, k):
        """Return the ``metadata['text']`` of up to ``k`` leading matches.

        Fewer than ``k`` texts are returned when the result holds fewer
        matches (this used to raise ``IndexError``). A missing or empty
        ``'matches'`` entry, or ``k <= 0``, gives an empty list.
        """
        found = matches.get('matches') or []
        # max() guards against a negative k turning into a "drop last" slice.
        return [match['metadata']['text'] for match in found[:max(k, 0)]]

    def write_to_file(self, filename):
        """Write the concepts table to ``filename`` as CSV without the row index."""
        self._df.to_csv(filename, index=False)
class Head_Agent:
    """Top-level controller: owns the clients, the sub-agents and the Streamlit UI.

    Conversation state is kept in ``st.session_state`` (``messages``,
    ``stage``, ``query_embedding``, ``topic_matrix``, ``openai_model``).
    The ``*_num`` classification constants and ``process_button_click`` are
    resolved from module scope at call time.
    """

    def __init__(self, openai_key, pinecone_key) -> None:
        """Create the OpenAI/Pinecone clients, open both indexes and build the sub-agents."""
        self.openai_key = openai_key
        self.pinecone_key = pinecone_key
        self.selected_mode = ""
        self.openai_client = OpenAI(api_key=self.openai_key)
        self.pc = Pinecone(api_key=self.pinecone_key)
        self.pinecone_index = self.pc.Index("index-600")
        self.pinecone_index_python = self.pc.Index("index-python-files")
        self.query_embedding_local = None
        self.setup_sub_agents()

    def setup_sub_agents(self):
        """Instantiate every sub-agent with the shared clients and indexes."""
        self.classify_agent = Classify_Agent(self.openai_client)
        self.query_agent = Query_Agent(self.pinecone_index, self.pinecone_index_python, self.openai_client)
        self.answering_agent = Answering_Agent(self.openai_client)
        self.relevant_documents_agent = Relevant_Documents_Agent(self.openai_client)
        self.ca = Concepts_Agent()

    # The next three functions take no ``self`` yet are called both through
    # the class (``Head_Agent.get_history_about(...)``) and through instances
    # (``self.get_conversation()``); ``@staticmethod`` makes both forms work.

    @staticmethod
    def get_conversation():
        """Return the last two user messages of the session."""
        return Head_Agent.get_history_about()

    @staticmethod
    def get_history_about(topic=None):
        """Return at most the last two matching message contents.

        With ``topic=None`` only messages whose role is ``"user"`` match;
        otherwise every message whose ``"topic"`` equals ``topic`` matches.
        Each returned entry is the message content followed by one space.
        """
        history = []
        for message in st.session_state.messages:
            role = message["role"]
            content = message["content"]
            if topic is None:
                if role == "user":
                    history.append(f"{content} ")
            elif message["topic"] == topic:
                history.append(f"{content} ")
        return history[-2:]

    @staticmethod
    def text_to_image(openai_client, text):
        """Generate one 1024x1024 DALL-E 3 image for ``text`` and return it as a PIL image."""
        with st.spinner("Generating ..."):
            response = openai_client.images.generate(
                model="dall-e-3",
                prompt=text,
                n=1,
                size="1024x1024",
            )
            image_url = response.data[0].url
            # Read the bytes fully so the image no longer depends on the
            # HTTP response once the ``with`` block closes it.
            with urllib.request.urlopen(image_url) as remote_image:
                img = Image.open(BytesIO(remote_image.read()))
        return img

    def get_default_value(self, variable):
        """Return a fresh default for the session-state key ``variable``.

        Unknown keys write an error message to the page and return ``None``.
        """
        # Factories (not values) so mutable defaults are never shared.
        default_factories = {
            "openai_model": lambda: "gpt-3.5-turbo",
            "messages": lambda: [],
            "stage": lambda: 0,
            "query_embedding": lambda: None,
            "topic_matrix": lambda: [[0] * 5 for _ in range(12)],
        }
        make_default = default_factories.get(variable)
        if make_default is None:
            st.write(f"Error: get_default_value, variable not defined: {variable}")
            return None
        return make_default()

    def initialize_session_state(self):
        """Populate every missing session-state key with its default."""
        session_state_variables = ["openai_model", "messages", "stage", "query_embedding", "topic_matrix"]
        for variable in session_state_variables:
            if variable not in st.session_state:
                st.session_state[variable] = self.get_default_value(variable)

    def display_selection_options(self):
        """Show the education-level selector and store the choice in ``selected_mode``."""
        modes = ['college student', 'middle school student', '1st grade student', 'high school student', 'grad student']
        self.selected_mode = st.selectbox("Select your education level:", modes)

    def display_chat_messages(self):
        """Re-render every stored message, including assistant images."""
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message("assistant"):
                    st.write(message["content"])
                    if message['image'] is not None:
                        st.image(message['image'])
            else:
                with st.chat_message("user"):
                    st.write(message["content"])

    def main_loop(self):
        """Render the page, handle a newly submitted query and offer follow-up topics."""
        st.title("Machine Learning Text Guide Chatbot")
        self.initialize_session_state()
        self.display_selection_options()
        self.display_chat_messages()

        if user_query := st.chat_input("What would you like to chat about?"):
            self._handle_user_query(user_query)

        # Checked outside the chat-input branch: stage and the query
        # embedding are kept in session state, so the buttons are rendered
        # again on runs where no new query was submitted (e.g. after a click).
        if st.session_state.stage == 1:
            self._offer_concept_topics()

    def _handle_user_query(self, user_query):
        """Classify ``user_query``, produce the reply, show it and store both messages."""
        with st.chat_message("user"):
            st.write(user_query)

        with st.chat_message("assistant"):
            response = ""
            topic = None
            image = None
            caption = ""
            st.session_state.stage = 0

            # Classify the query together with the recent user messages.
            conversation = self.get_conversation()
            user_query_plus_conversation = f"The current query is: {user_query}. The current conversation is: {conversation}"
            classify_query = self.classify_agent.classify_query(user_query_plus_conversation)

            if classify_query == general_greeting_num:
                response = "How can I assist you today?"
            elif classify_query == general_question_num:
                response = "Please ask a question about Machine Learning or Python Code."
            elif classify_query == obnoxious_num:
                response = "Please dont be obnoxious."
            elif classify_query == progress_num:
                self.ca.display_topic_matrix_star()
            elif classify_query == default_num:
                response = "I'm not sure how to respond to that."
            elif classify_query == machine_learning_num:
                response = self.query_agent.process_query_response(self, user_query, 'ml')
                # Stored for the follow-up topic suggestions (stage 1).
                st.session_state.query_embedding = self.query_agent.get_openai_embedding(user_query)
                image, caption = self.answering_agent.generate_image(response)
                topic = "ml"
                st.session_state.stage = 1
            elif classify_query == python_code_num:
                response = self.query_agent.process_query_response(self, user_query, 'python')
                image, caption = self.answering_agent.generate_image(response)
                topic = "python"
                st.session_state.stage = 0
            else:
                response = "I'm not sure how to respond to that."

            st.write(response)
            if image and caption != "":
                st.image(image, caption)

            st.session_state.messages.append({"role": "user", "content": user_query, "topic": topic, "image": None})
            st.session_state.messages.append({"role": "assistant", "content": response, "topic": topic, "image": image})

    def _offer_concept_topics(self):
        """Name the concept closest to the stored query embedding and offer its unexplored topics."""
        top_concept_index = self.ca.find_top_concept_index(st.session_state.query_embedding)
        concept_name = self.ca._df['concept'][top_concept_index]
        st.write(f"Your question is associated to the Fundamental Concept in Machine Learning: {concept_name}.\n\n")
        st.write(f"Here are some topics you can explore to help you learn about {concept_name}, pick one.")

        # Read all five descriptions up front, as the original code did.
        topic_descs = [self.ca._df[f"topic_{i}_desc"][top_concept_index] for i in range(5)]
        matrix_row = st.session_state.topic_matrix[top_concept_index]

        for button_index, topic_desc in enumerate(topic_descs):
            # ``stage`` is re-read every iteration: a handled click resets it
            # to 0, which hides the remaining buttons for this run.
            if matrix_row[button_index] == 0 and st.session_state.stage:
                if st.button(topic_desc):
                    process_button_click(self, button_index, topic_desc, top_concept_index)
def process_button_click(head, button_index, topic_desc, top_concept_index):
    """Handle a click on a suggested-topic button.

    Shows ``topic_desc`` as a user message, stores its embedding in the
    session state, generates a summary and an image for the topic, marks the
    topic as covered in the progress matrix, records both chat messages and
    resets ``stage`` to 0.

    Args:
        head: The ``Head_Agent`` that owns the sub-agents.
        button_index: Topic number; selects the ``topic_<n>`` column.
        topic_desc: Text of the clicked button.
        top_concept_index: Row of the concept the topic belongs to.
    """
    with st.chat_message("user"):
        st.write(topic_desc)

    # The clicked topic becomes the new reference embedding for the session.
    st.session_state.query_embedding = head.query_agent.get_openai_embedding(topic_desc)

    column_name = 'topic_' + str(button_index)
    source_text = head.ca._df[column_name][top_concept_index]

    answer = head.answering_agent.generate_response_topic(
        topic_desc, source_text, head.get_conversation(), head.selected_mode
    )
    picture, picture_caption = head.answering_agent.generate_image(source_text)

    # Mark this topic as covered.
    st.session_state.topic_matrix[top_concept_index][button_index] += 1

    st.write(answer)
    if picture and picture_caption != "":
        st.image(picture, picture_caption)

    # The user entry is tagged "ml"; the assistant entry is tagged with the
    # topic description itself.
    user_entry = {"role": "user", "content": topic_desc, "topic": "ml", "image": None}
    assistant_entry = {"role": "assistant", "content": answer, "topic": topic_desc, "image": picture}
    st.session_state.messages.append(user_entry)
    st.session_state.messages.append(assistant_entry)

    st.session_state.stage = 0
if __name__ == "__main__":
    # Script entry point: disable debug mode (generate_image reads DEBUG),
    # build the head agent from the configured keys and render the app.
    DEBUG = False
    head_agent = Head_Agent(OPENAI_KEY, pc_apikey)
    head_agent.main_loop()