import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM


class RAG:
    def __init__(self):
        self.model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedding_model_name = "all-mpnet-base-v2"
        self.embeddings_filename = "embeddings.csv"

        # Load the pre-computed chunk/embedding table once and keep both a
        # DataFrame (for the embedding column) and a list of dicts (for
        # looking up chunk text by retrieved index).
        self.data_pd = pd.read_csv(self.embeddings_filename)
        self.data_dict = self.data_pd.to_dict(orient="records")
        self.data_embeddings = self.get_embeddings()

        # Embedding model
        self.embedding_model = SentenceTransformer(
            model_name_or_path=self.embedding_model_name, device=self.device
        )

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.model_id
        )

        # LLM
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=self.model_id
        ).to(self.device)

    def get_embeddings(self) -> torch.Tensor:
        """Parse the stringified embeddings in the CSV back into a single
        (num_chunks, embedding_dim) tensor on the target device."""
        data_embeddings = []
        for tensor_str in self.data_pd["embeddings"]:
            # Each cell looks like "tensor([0.1, 0.2, ...])"; recover the
            # float values between the brackets.
            values_str = tensor_str.split("[")[1].split("]")[0]
            values_list = [float(val) for val in values_str.split(",")]
            data_embeddings.append(torch.tensor(values_list))
        return torch.stack(data_embeddings).to(self.device)

    def retrieve_relevant_resource(self, user_query: str, k: int = 5):
        """Embed the query and return the top-k (scores, indices) of the
        most similar chunks by dot-product similarity."""
        query_embedding = self.embedding_model.encode(
            user_query, convert_to_tensor=True
        ).to(self.device)
        dot_scores = util.dot_score(a=query_embedding, b=self.data_embeddings)[0]
        scores, indices = torch.topk(dot_scores, k=k)
        return scores, indices

    def prompt_formatter(self, query: str, context_items: list[dict]) -> str:
        """Augments the query with text-based context from context_items."""
        # Join context items into one bulleted list
        context = "- " + "\n- ".join(item["sentence_chunk"] for item in context_items)

        base_prompt = """You are a friendly lawyer chatbot who always responds in the style of a judge.
Based on the following context items, please answer the user query.

Relevant passages:
{context}
"""

        # Update base prompt with context items
        base_prompt = base_prompt.format(context=context)

        # Create prompt template for instruction-tuned model
        dialogue_template = [
            {"role": "system", "content": base_prompt},
            {"role": "user", "content": query},
        ]

        # Apply the chat template
        prompt = self.tokenizer.apply_chat_template(
            conversation=dialogue_template, tokenize=False, add_generation_prompt=True
        )
        return prompt

    def query(self, user_text: str) -> str:
        """Run the full RAG loop: retrieve, augment the prompt, generate."""
        scores, indices = self.retrieve_relevant_resource(user_text)
        context_items = [self.data_dict[i] for i in indices.tolist()]
        prompt = self.prompt_formatter(query=user_text, context_items=context_items)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.llm_model.generate(**inputs, max_new_tokens=512)
        output_text = self.tokenizer.decode(outputs[0])

        # Keep only the assistant's reply: everything after the final
        # "<|assistant|>" marker and before the end-of-sequence token.
        output_text = output_text.split("<|assistant|>")[-1]
        output_text = output_text.split(self.tokenizer.eos_token)[0].strip()
        return output_text
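
# --- Usage sketch (illustrative; not part of the class above) ---
# A minimal way to exercise the pipeline, assuming "embeddings.csv" exists in
# the working directory with a "sentence_chunk" column (the chunk text) and an
# "embeddings" column holding stringified tensors like "tensor([0.1, ...])",
# e.g. as produced by writing SentenceTransformer outputs via DataFrame.to_csv.
# The example question is hypothetical.
if __name__ == "__main__":
    rag = RAG()
    answer = rag.query("What are the tenant's obligations under the lease?")
    print(answer)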