import yaml
import fitz
import torch
import gradio as gr
import numpy as np
import faiss
import spaces
from PIL import Image
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, pipeline
from pymilvus import connections, Collection
from datasets import load_from_disk
from pastebin_api import get_protected_content


class RAGbot:
    def __init__(self, config_path="config.yaml"):
        self.processed = False
        self.page = 0
        self.chat_history = []
        self.prompt = None
        self.documents = None
        self.embeddings = None
        self.zilliz_vectordb = None
        self.hf_vectordb = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.chain = None
        self.chunk_size = 512
        self.overlap_percentage = 50
        self.max_chunks_in_context = 2
        self.current_context = None
        self.model_temperature = 0.5
        self.format_separator = "\n\n--\n\n"
        self.pipe = None
        self.persona_text = None
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.model_embeddings = config["modelEmbeddings"]
        self.auto_tokenizer = config["autoTokenizer"]
        self.auto_model_for_causal_lm = config["autoModelForCausalLM"]
        self.zilliz_config = config["zilliz"]
        self.persona_paste_key = config["personaPasteKey"]

    def connect_to_zilliz(self):
        """Open a connection to the Zilliz/Milvus cluster and bind the collection."""
        connections.connect(
            host=self.zilliz_config["host"],
            port=self.zilliz_config["port"],
            user=self.zilliz_config["user"],
            password=self.zilliz_config["password"],
            secure=True,
        )
        self.zilliz_vectordb = Collection(self.zilliz_config["collection"])

    def load_embeddings(self):
        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_embeddings)

    def load_hf_vectordb(self, dataset_path, index_path):
        """Load a local HuggingFace dataset plus its prebuilt FAISS index."""
        dataset = load_from_disk(dataset_path)
        index = faiss.read_index(index_path)
        self.hf_vectordb = (dataset, index)

    @spaces.GPU
    def load_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.auto_tokenizer)

    @spaces.GPU
    def create_organic_pipeline(self):
        self.pipe = pipeline(
            "text-generation",
            model=self.auto_model_for_causal_lm,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device="cuda",
        )

    def get_organic_context(self, query, use_hf=False):
        """Retrieve the top chunks for `query` from FAISS or Zilliz and cache them."""
        if use_hf:
            dataset, index = self.hf_vectordb
            # FAISS expects a float32 query matrix.
            query_vector = np.array([self.embeddings.embed_query(query)], dtype=np.float32)
            _, indices = index.search(query_vector, self.max_chunks_in_context)
            # Assumes each dataset row stores its chunk under a "text" column,
            # mirroring the Zilliz schema used below.
            context = self.format_separator.join(dataset[int(i)]["text"] for i in indices[0])
        else:
            result = self.zilliz_vectordb.search(
                data=[self.embeddings.embed_query(query)],
                anns_field="embeddings",
                param={"metric_type": "IP", "params": {"nprobe": 10}},
                limit=self.max_chunks_in_context,
                expr=None,
                output_fields=["text"],  # required so hit.entity.get('text') is populated
            )
            context = self.format_separator.join([hit.entity.get("text") for hit in result[0]])
        self.current_context = context

    def load_persona_data(self):
        persona_content = get_protected_content(self.persona_paste_key)
        persona_data = yaml.safe_load(persona_content)
        self.persona_text = persona_data["persona_text"]

    @spaces.GPU
    def create_organic_response(self, history, query, use_hf=False):
        self.get_organic_context(query, use_hf=use_hf)
        messages = [
            {
                "role": "system",
                "content": f"Based on the given context, answer the user's question while maintaining the persona:\n{self.persona_text}\n\nContext:\n{self.current_context}",
            },
            {"role": "user", "content": query},
        ]
        prompt = self.pipe.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        outputs = self.pipe(
            prompt,
            max_new_tokens=1024,
            do_sample=True,
            temperature=self.model_temperature,  # honor the UI-selected temperature instead of a hard-coded 0.1
            top_p=0.9,
        )
        # Strip the prompt echo so only the newly generated text is returned.
        return outputs[0]["generated_text"][len(prompt):]

    def process_file(self, file):
        self.documents = PyPDFLoader(file.name).load()
        self.load_embeddings()
        self.connect_to_zilliz()

    @spaces.GPU
    def generate_response(self, history, query, file, chunk_size, chunk_overlap_percentage,
                          model_temperature, max_chunks_in_context, use_hf_index=False,
                          hf_dataset_path=None, hf_index_path=None):
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperature = model_temperature
        self.max_chunks_in_context = max_chunks_in_context
        if not query:
            raise gr.Error(message='Submit a question')
        # The persona and the generation pipeline must exist before
        # create_organic_response references them.
        self.load_persona_data()
        if self.pipe is None:
            self.create_organic_pipeline()
        if use_hf_index:
            if not hf_dataset_path or not hf_index_path:
                raise gr.Error(message='Provide HuggingFace dataset and index paths')
            if self.embeddings is None:
                self.load_embeddings()
            self.load_hf_vectordb(hf_dataset_path, hf_index_path)
            result = self.create_organic_response(history="", query=query, use_hf=True)
        else:
            if not file:
                raise gr.Error(message='Upload a PDF')
            if not self.processed:
                self.process_file(file)
                self.processed = True
            result = self.create_organic_response(history="", query=query)
        result = f"{self.persona_text}\n\n{result}"
        # Stream the answer into the last chat turn character by character.
        for char in result:
            history[-1][-1] += char
        return history, ""

    def render_file(self, file, chunk_size, chunk_overlap_percentage, model_temperature,
                    max_chunks_in_context):
        """Render the current PDF page as a PIL image at 300 DPI."""
        doc = fitz.open(file.name)
        page = doc[self.page]
        self.chunk_size = chunk_size
        self.overlap_percentage = chunk_overlap_percentage
        self.model_temperature = model_temperature
        self.max_chunks_in_context = max_chunks_in_context
        pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72))
        image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
        doc.close()
        return image

    def add_text(self, history, text):
        if not text:
            raise gr.Error('Enter text')
        # Append a mutable [user, bot] pair; generate_response mutates it in
        # place, which a tuple would not allow.
        history.append([text, ''])
        return history
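
# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes a config.yaml providing the keys read in __init__ (modelEmbeddings,
# autoTokenizer, autoModelForCausalLM, zilliz, personaPasteKey), a reachable
# Zilliz collection, and a local "example.pdf"; UploadedFile is a hypothetical
# stand-in for Gradio's file object, which the methods above only touch
# through its .name attribute.
if __name__ == "__main__":
    bot = RAGbot(config_path="config.yaml")
    question = "What does the uploaded document say about pricing?"
    history = bot.add_text([], question)

    class UploadedFile:
        name = "example.pdf"  # hypothetical local path

    history, _ = bot.generate_response(
        history=history,
        query=question,
        file=UploadedFile(),
        chunk_size=512,
        chunk_overlap_percentage=50,
        model_temperature=0.5,
        max_chunks_in_context=2,
    )
    print(history[-1][1])  # the bot's half of the last chat turn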