RAG chatbot using Llama 3
Definition
First, let's define what RAG is: Retrieval-Augmented Generation. It is a technique used in natural language processing (NLP) to improve the performance of language models by incorporating external knowledge sources, such as databases or search engines. The basic idea is to retrieve relevant information from an external source based on the input query, then pass that information to the language model along with the query so it can generate a response grounded in the retrieved context.
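In other words, the prompt is augmented with retrieved context before generation. Here is a minimal conceptual sketch of that flow (the retrieve and llm_generate helpers are hypothetical placeholders, not part of the code later in this post):

def rag_answer(query: str) -> str:
    docs = retrieve(query)  # 1. fetch relevant documents from an external knowledge source (hypothetical helper)
    prompt = f"Context: {docs}\nQuestion: {query}"  # 2. augment the prompt with the retrieved context
    return llm_generate(prompt)  # 3. let the language model answer, grounded in that context (hypothetical helper)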
Tools
For this blog, we will need the following libraries:
pip install -q datasets sentence-transformers faiss-cpu accelerate
Embed the original dataset
Embedding the dataset is a necessary step, and it is by far the slowest one on our list, so we advise you to select a GPU. We also recommend that you embed your dataset once and save it or push it to the Hub, to avoid redoing this work every time.
Let's start by loading our original dataset:
from datasets import load_dataset
dataset = load_dataset("not-lain/wikipedia")
dataset # let's check out our dataset
>>> DatasetDict({
        train: Dataset({
            features: ['id', 'url', 'title', 'text'],
            num_rows: 3000
        })
    })
Then we load our embedding model. I'm going to go with mixedbread-ai/mxbai-embed-large-v1:
from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
Now let's embed our dataset:
def embed(batch):
    """
    adds a column to the dataset called 'embeddings'
    """
    # or you can combine multiple columns here,
    # for example the title and the text (see the sketch below)
    information = batch["text"]
    return {"embeddings": ST.encode(information)}

dataset = dataset.map(embed, batched=True, batch_size=16)
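If you would rather combine several columns, for instance the title and the text, a variant of embed could look like this (a sketch; the rest of this post embeds only the text column):

def embed_title_and_text(batch):
    # concatenate the title and the text of each article before embedding
    information = [f"{title}\n{text}" for title, text in zip(batch["title"], batch["text"])]
    return {"embeddings": ST.encode(information)}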
It is advised that you save your dataset to avoid going through this step each time. To keep the original dataset intact for all users, I will push the embedded one to a new branch; this can easily be done using the revision parameter:
dataset.push_to_hub("not-lain/wikipedia", revision="embedded")
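If you prefer to keep everything local instead of pushing to the Hub, the datasets library's save_to_disk and load_from_disk work as well (a sketch, with a hypothetical directory name):

# save the embedded dataset locally ...
dataset.save_to_disk("wikipedia-embedded")

# ... and reload it in a later session
from datasets import load_from_disk
dataset = load_from_disk("wikipedia-embedded")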
Search through the dataset
First, load your dataset from the Hub:
from datasets import load_dataset
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
Then add a FAISS index using the embeddings column that we created:
data = dataset["train"]
data = data.add_faiss_index("embeddings")
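Rebuilding the index at every session is fine for 3,000 rows, but for larger datasets you can persist it with save_faiss_index and reload it with load_faiss_index (a sketch, with a hypothetical file name):

# persist the index once ...
data.save_faiss_index("embeddings", "wikipedia.faiss")

# ... and in a later session, attach the saved index instead of rebuilding it
data.load_faiss_index("embeddings", "wikipedia.faiss")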
Let's define a search function:
def search(query: str, k: int = 3):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query)  # embed new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
        k=k,  # get only top k results
    )
    return scores, retrieved_examples
# search for the word "anarchy" and get the best 4 matching values from the dataset
scores, result = search("anarchy", 4)
result["title"]
>>> ['Anarchism', 'Anarcho-capitalism', 'Community', 'Capitalism']
print(result["text"][0])
>>>"Anarchism is a political philosophy and movement that is skeptical of all justifications for authority and (...)"
RAG chatbot
The following is a draft of what a RAG chatbot might look like:
embed (only once)
│
└── new query
│
└── retrieve
│
└─── format prompt
│
└── GenAI
│
└── generate response
Now let's tie everything together in a new session, after the embedding has been done:
pip install -q datasets sentence-transformers faiss-cpu accelerate bitsandbytes
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
def search(query: str, k: int = 3):
    """a function that embeds a new query and returns the most probable results"""
    embedded_query = ST.encode(query)  # embed new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
        k=k,  # get only top k results
    )
    return scores, retrieved_examples
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# use 4-bit quantization to lower GPU memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)
# stop generation at either the end-of-sequence token or Llama 3's end-of-turn token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
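Note that meta-llama/Meta-Llama-3-8B-Instruct is a gated model: you need to accept its license on the model page and authenticate before from_pretrained can download the weights, for example:

# authenticate with a Hugging Face token that has access to the gated model
from huggingface_hub import login
login()  # or run `huggingface-cli login` from a terminal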
We recommend you set up a system prompt to guide the LLM in generating responses:
SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""
def format_prompt(prompt, retrieved_documents, k):
    """using the retrieved documents we will prompt the model to generate our responses"""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT
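For illustration, here is what a call might look like, reusing the search function defined above (the result is sliced only to keep the printout short):

scores, retrieved_documents = search("what's anarchy?", 2)
print(format_prompt("what's anarchy?", retrieved_documents, 2)[:300])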
def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # tell the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]  # keep only the newly generated tokens
    return tokenizer.decode(response, skip_special_tokens=True)
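Note that the [:2000] slice above truncates characters, not tokens; it is a blunt guard against running out of memory. A token-aware alternative could clip the context with the tokenizer instead (a sketch, with a hypothetical token budget):

def truncate_to_tokens(text: str, max_tokens: int = 1500) -> str:
    # encode, clip to the token budget, and decode back to a string
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return tokenizer.decode(tokens[:max_tokens])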
def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

rag_chatbot("what's anarchy?", k=2)
>>>"So, anarchism is a political philosophy that questions the need for authority and hierarchy, and (...)"
Demo
A demo application where you can try out this chatbot can be found here.
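If you want to build a similar demo yourself, a minimal sketch using Gradio's ChatInterface, wrapping the rag_chatbot function defined above, could look like this:

import gradio as gr

def chat(message, history):
    # answer each question independently, ignoring the chat history
    return rag_chatbot(message, k=2)

gr.ChatInterface(chat, title="RAG chatbot using Llama 3").launch()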
Dedication
In loving memory of Rayner V. Giuret, a friend, a brother, and an idol to all of us at LowRes.
Your legacy lives on in our hearts and minds. Thanks for everything.
Rest in peace, Rayner.