File size: 4,393 Bytes
39451f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import transformers
from transformers import RagRetriever, RagSequenceForGeneration, AutoModelForCausalLM, pipeline
import gradio as gr
# Where the prebuilt passages dataset and its FAISS HNSW index live on disk.
dataset_path = "./5k_index_data/my_knowledge_dataset"
index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"

# Prefer the GPU whenever PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Retriever over the local ("custom") index; 5 passages are fetched per query.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
    n_docs=5,
)

# RAG model wired to the retriever above; the query-time code in this file
# uses its question encoder and retriever.
rag_model = RagSequenceForGeneration.from_pretrained(
    'facebook/rag-sequence-nq',
    retriever=retriever,
)
rag_model.retriever.init_retrieval()
rag_model.to(device)

# Instruction-tuned Gemma 2 (2B) writes the final answer, loaded in bfloat16.
pipe = pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
    device=device,
)
def strip_title(title):
    """Return *title* without one leading and one trailing double quote.

    At most a single ``"`` is removed from each end; inner quotes and any
    additional surrounding quotes are left alone. Used to clean passage
    titles before they are placed in the prompt context.
    """
    return title.removeprefix('"').removesuffix('"')
def _retrieve_passages(query, rag_model):
    """Return one ``"title: text"`` string per passage retrieved for *query*.

    The query is embedded with the RAG question encoder and looked up in the
    retriever's index, fetching ``rag_model.config.n_docs`` passages.
    """
    tokenizer = rag_model.retriever.question_encoder_tokenizer
    input_ids = tokenizer(
        [query],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )["input_ids"].to(device)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        pooled = rag_model.rag.question_encoder(input_ids)[0]

    # The index lookup is fed a float32 NumPy array on the CPU.
    query_vectors = pooled.cpu().to(torch.float32).numpy()

    # retrieve() returns (doc_embeds, doc_ids, doc_dicts). Only the passage
    # dicts are needed, so the generator-side tokenisation that
    # RagRetriever.__call__ performs is skipped.
    _, _, doc_dicts = rag_model.retriever.retrieve(
        query_vectors, rag_model.config.n_docs
    )
    return [
        f"{strip_title(title)}: {text}"
        for docs in doc_dicts
        for title, text in zip(docs["title"], docs["text"])
    ]


def _build_prompt(query, passages, system_message=None):
    """Compose the single user turn sent to Gemma.

    Per Google's Gemma prompt-formatting docs, Gemma has no system role, so
    the optional system message, the retrieved context and the answering
    instructions are all folded into one user message ahead of the query.
    """
    context = "\n".join(f"- {passage}" for passage in passages)
    parts = []
    if system_message:
        parts.append(system_message.strip())
    parts.append(f"Context:\n{context}")
    parts.append(
        "Use the links and information from the Context to answer the query "
        "in brief. Provide links in the answer."
    )
    parts.append(f"Query: {query}")
    return "\n\n".join(parts)


def retrieved_info(
    query,
    rag_model=rag_model,
    *,
    max_new_tokens=256,
    temperature=None,
    top_p=None,
    system_message=None,
):
    """Answer *query* using retrieved passages as context for Gemma.

    Args:
        query: The user's question.
        rag_model: RAG model whose question encoder and retriever are used.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. When this and ``top_p`` are both
            None, the pipeline's default generation settings are used.
        top_p: Nucleus-sampling probability mass.
        system_message: Optional extra instructions placed before the context.

    Returns:
        The stripped text of the model's reply.
    """
    passages = _retrieve_passages(query, rag_model)
    messages = [
        {"role": "user", "content": _build_prompt(query, passages, system_message)}
    ]

    # Gradio sliders may deliver floats; max_new_tokens must be an int.
    generate_kwargs = {"max_new_tokens": int(max_new_tokens)}
    # Sample only when the caller asks for it. Otherwise fall back to the
    # model's default generation settings, as a plain call would.
    if temperature is not None or top_p is not None:
        generate_kwargs["do_sample"] = True
        if temperature is not None:
            generate_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            generate_kwargs["top_p"] = float(top_p)

    outputs = pipe(messages, **generate_kwargs)
    # With chat-style input the pipeline returns the whole conversation;
    # the last message is the model's reply.
    return outputs[0]["generated_text"][-1]["content"].strip()


def respond(
    message,
    history: list[dict],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback: answer *message* via retrieval + Gemma.

    ``history`` is accepted for the ChatInterface signature but unused; each
    query is answered independently. The remaining arguments come from the
    interface's additional inputs (system message and generation sliders)
    and are forwarded to generation.
    """
    if not message:
        # Nothing to answer.
        return ""
    return retrieved_info(
        message,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        system_message=system_message,
    )
# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface

# Custom title and description shown at the top of the app.
title = "🧠 Welcome to Your AI Knowledge Assistant"
description = """
Hi! I am your loyal assistant. My functionality is based on a RAG model: I retrieve relevant information and provide answers based on it. Ask me any question, and let me assist you.
My capabilities are limited because I am still in the development phase. I will do my best to assist you. So, let's begin!
"""

demo = gr.ChatInterface(
    respond,
    type="messages",
    # Passed to respond() after (message, history), in this order.
    additional_inputs=[
        gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title=title,
    description=description,
    submit_btn=True,
    # placeholder must be a plain string (a list is not valid here).
    textbox=gr.Textbox(placeholder="'What is the future of AI?' or 'App Development'"),
    examples=[["Future of AI"], ["App Development"]],
    # NOTE(review): "compact" is not one of Gradio's built-in theme names;
    # confirm it resolves to the intended theme.
    theme="compact",
)

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)
|