import pandas as pd
import numpy as np
from openai import OpenAI
from sentence_transformers import util, SentenceTransformer
import torch
import time
from time import perf_counter as timer
from datetime import datetime
import textwrap
import json
import gradio as gr
# OpenAI client, constructed with default settings (no explicit arguments).
client = OpenAI()

# Load the saved chunk/embedding table from disk.
embeddings_df_save_path = "./SRD_embeddings.csv"
print("Loading embeddings.csv")
text_chunks_and_embedding_df_load = pd.read_csv(embeddings_df_save_path)
print("Embedding file loaded")

# Load the sentence-embedding model onto the CPU.
embedding_model_path = "BAAI/bge-m3"
print("Loading embedding model")
embedding_model = SentenceTransformer(
    model_name_or_path=embedding_model_path,
    device="cpu",
)


def _embedding_from_json(serialized):
    """Decode one JSON-encoded embedding string into a NumPy array."""
    return np.array(json.loads(serialized))


# The 'embedding_str' column holds JSON text; decode each cell into an array.
text_chunks_and_embedding_df_load["embedding"] = (
    text_chunks_and_embedding_df_load["embedding_str"].apply(_embedding_from_json)
)

# One dict per table row, so a row can be looked up by its positional index.
pages_and_chunks = text_chunks_and_embedding_df_load.to_dict(orient="records")

# Stack all embeddings into a single float32 tensor on the CPU
# (the decoded NumPy arrays are float64, hence the explicit dtype).
embeddings = torch.tensor(
    np.array(text_chunks_and_embedding_df_load["embedding"].tolist()),
    dtype=torch.float32,
).to("cpu")
# Define helper function to print wrapped text
def print_wrapped(text, wrap_length=80):
    """Print ``text`` re-flowed to lines of at most ``wrap_length`` columns.

    Args:
        text: The string to wrap and print.
        wrap_length: Maximum line width handed to ``textwrap.fill``.

    Returns:
        None. The wrapped text is written to standard output.
    """
    # Emit the wrapped text; without this the helper would compute it and
    # then silently discard it.
    print(textwrap.fill(text, wrap_length))
def hybrid_estimate_tokens(text: str, *, average_token_length: int = 4) -> float:
    """Roughly estimate how many tokens ``text`` would occupy.

    Two cheap heuristics are averaged:

    1. words (whitespace-separated) plus punctuation marks, and
    2. characters divided by ``average_token_length`` plus punctuation marks.

    Args:
        text: The text to measure.
        average_token_length: Assumed number of characters per token for the
            length-based heuristic. Must be non-zero.

    Returns:
        The mean of the two estimates as a float; ``0.0`` for empty text.
    """
    # Each of these marks is treated as a potential separate token.
    punctuation_count = sum(map(text.count, ",.!?;:"))

    # Split on any whitespace so that empty text yields zero words, repeated
    # spaces are not counted as extra words, and newline/tab-separated words
    # are still counted.
    word_based = len(text.split()) + punctuation_count

    length_based = len(text) // average_token_length + punctuation_count

    return (word_based + length_based) / 2
def retrieve_relevant_resources(query: str,
                                embeddings: torch.Tensor,
                                model: SentenceTransformer = embedding_model,
                                n_resources_to_return: int = 4,
                                print_time: bool = True):
    """Embed ``query`` with ``model`` and return the top-k scores and indices.

    Args:
        query: The text to search for.
        embeddings: Tensor holding one embedding per row to score against.
        model: Sentence-embedding model used to encode ``query``.
        n_resources_to_return: Maximum number of results (k) to return.
        print_time: When true, print how long the scoring step took.

    Returns:
        A ``(scores, indices)`` tuple as produced by ``torch.topk`` over the
        dot-product scores between the query and every row of ``embeddings``.
    """
    # Embed the query
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Get dot product scores on embeddings (only this step is timed)
    start_time = timer()
    dot_scores = util.dot_score(query_embedding, embeddings)[0]
    end_time = timer()

    if print_time:
        print(f"[INFO] Time taken to get scores on {len(embeddings)} embeddings: {end_time-start_time:.5f} seconds.")

    # torch.topk raises when k exceeds the number of candidates, so never ask
    # for more results than there are embeddings.
    k = min(n_resources_to_return, dot_scores.shape[0])
    scores, indices = torch.topk(input=dot_scores, k=k)
    return scores, indices
def print_top_results_and_scores(query: str,
                                 embeddings: torch.Tensor,
                                 pages_and_chunks: list[dict] = pages_and_chunks,
                                 n_resources_to_return: int = 5):
    """Retrieve the most relevant resources for ``query`` and print a summary.

    For every hit the score, the chunk's token count and its file of origin
    are printed, in the order returned by ``retrieve_relevant_resources``.

    Note: every record in ``pages_and_chunks`` must provide the
    ``'chunk_token_count'`` and ``'file_path'`` keys.

    Args:
        query: The text to search for.
        embeddings: Tensor of embeddings to score the query against.
        pages_and_chunks: Records aligned by position with ``embeddings``.
        n_resources_to_return: Number of results to retrieve and print.

    Returns:
        The ``(scores, indices)`` tuple from ``retrieve_relevant_resources``.
    """
    # Forward the requested result count so this function's own default is
    # honoured rather than the retrieval helper's default.
    scores, indices = retrieve_relevant_resources(
        query=query,
        embeddings=embeddings,
        n_resources_to_return=n_resources_to_return,
    )

    print(f"Query: {query}\n")
    # Walk scores and indices together; plain Python numbers keep the list
    # lookup and the float formatting independent of tensor semantics.
    for score, index in zip(scores.tolist(), indices.tolist()):
        record = pages_and_chunks[index]
        print(f"Score: {score:.4f}")
        print(f"Token Count : {record['chunk_token_count']}")
        # Print the source file so the result can be checked against it
        print(f"File of Origin: {record['file_path']}")
    return scores, indices
def prompt_formatter(query: str,
                     context_items: list[dict]) -> str:
    """Build the user prompt from the retrieved context and the query.

    Args:
        query: The user's question (may already include chat history).
        context_items: Retrieved records; each must have a
            ``'sentence_chunk'`` key holding the chunk text.

    Returns:
        The prompt text: the context chunks as a dash-bulleted list followed
        by the user query.
    """
    # Join context items into one dash-bulleted paragraph
    context = "- " + "\n- ".join([item["sentence_chunk"] for item in context_items])

    # The template is a properly terminated literal so the return statement
    # below is code rather than part of the prompt text.
    base_prompt = (
        "Now use the following context items to answer the user query: {context}\n"
        "User query: {query}"
    )

    # Fill the template with the context items and the query
    return base_prompt.format(context=context, query=query)
# System instructions sent to the chat model together with the formatted
# prompt: they fix the persona (D&D 5e game design expert), list numbered
# response rules and give three example query/answer pairs for style.
# NOTE(review): the numbered rules jump from 4 to 6 - confirm whether an
# item 5 was dropped or the list is simply misnumbered.
system_prompt = """You are a game design expert specializing in Dungeons & Dragons 5e, answering beginner questions with descriptive, clear responses. Provide a story example. Avoid extraneous details and focus on direct answers. Use the examples provided as a guide for style and brevity. When responding:
1. Identify the key point of the query.
2. Provide a straightforward answer, omitting the thought process.
3. Avoid additional advice or extended explanations.
4. Answer in an informative manner, aiding the user's understanding without overwhelming them.
6. End with a line break and "What else can I help with?"
Refer to these examples for your response style:
Example 1:
Query: How do I determine what my magic ring does in D&D?
Answer: To learn what your magic ring does, use the Identify spell, take a short rest to study it, or consult a knowledgeable character. Once known, follow the item's instructions to activate and use its powers.
Example 2:
Query: What's the effect of the spell fireball?
Answer: Fireball is a 3rd-level spell creating a 20-foot-radius sphere of fire, dealing 8d6 fire damage (half on a successful Dexterity save) to creatures within. It ignites flammable objects not worn or carried.
Example 3:
Query: How do spell slots work for a wizard?
Answer: Spell slots represent your capacity to cast spells. You use a slot of equal or higher level to cast a spell, and you regain all slots after a long rest. You don't lose prepared spells after casting; they can be reused as long as you have available slots.
Use the context provided to answer the user's query concisely. """
# Gradio UI: a chatbot pane, a textbox and a clear button. Submitting the
# textbox runs `respond`, which retrieves context, queries the chat model and
# appends the exchange to the chat history.
with gr.Blocks() as RulesLawyer:
# State holders that receive copies of the textbox value and the chatbot
# value through the `change` handlers registered at the end of this block.
message_state = gr.State()
chatbot_state = gr.State([])
chatbot = gr.Chatbot()
msg = gr.Textbox()
# Clear button constructed over the textbox and the chatbot components.
clear = gr.ClearButton([msg, chatbot])
# Identity callback: returns its input unchanged so a component's value can
# be copied into a gr.State output.
def store_message(message):
return message
# Handle one submitted message: log it, retrieve context, call the model and
# return ("", chat_history) for the [msg, chatbot] outputs.
def respond(message, chat_history):
print(f"User Input : {message}")
print(f"Chat History: {chat_history}")
print(f"""Token Estimate: {hybrid_estimate_tokens(f"{message} {chat_history}")}""")
# Get relevant resources
# NOTE(review): the call below is cut off after its first argument in this
# view - the remaining arguments and the closing parenthesis are missing.
scores, indices = print_top_results_and_scores(query=message,
# Create a list of context items
context_items = [pages_and_chunks[i] for i in indices]
# Format prompt with context items
# NOTE(review): this call is also cut off - the context_items argument and
# the closing parenthesis are not visible here.
prompt = prompt_formatter(query=f"Chat History : {chat_history} + {message}",
# NOTE(review): the right-hand side of this assignment is missing. The
# following two lines look like one message dict of a chat-completion request
# (presumably made through `client`); the model name and the rest of the call
# cannot be recovered from this view - restore from the original source.
bot_message =
"role": "user",
"content": f"{system_prompt} {prompt}"
# Record the (user message, model reply) pair in the history.
chat_history.append((message, bot_message.choices[0].message.content))
print(f"Response : {bot_message.choices[0].message.content}")
# The empty string goes to the textbox output, the history to the chatbot.
return "", chat_history
# Copy the textbox value / chatbot value into their state holders on change.
msg.change(store_message, inputs = [msg], outputs = [message_state])
chatbot.change(store_message, [chatbot], [chatbot_state])
# On submit, run `respond` with the stored message and history.
msg.submit(respond, [message_state, chatbot_state], [msg, chatbot])
if __name__ == "__main__":