# DataChat / RAG.py
# Jkalonji's picture
# Adding a function to be used in Gradio
# fb5d8fe verified
# raw
# history blame
# 3.34 kB
import functools
import logging
import os
import re

from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from transformers import pipeline

from config import HUGGINGFACEHUB_API_TOKEN
# Export the Hugging Face token from the local config module into this
# process's environment.
os.environ.update({"HUGGINGFACEHUB_API_TOKEN": HUGGINGFACEHUB_API_TOKEN})

# Sentence-transformer model used to embed the PDF chunks and the queries.
# Any of the many models available on Hugging Face
# (https://huggingface.co/models) can be used here.
model_name = "llmware/industry-bert-insurance-v0.1"
def remove_special_characters(string):
    """Return *string* with every newline character replaced by a space.

    Only ``"\\n"`` is replaced; all other characters are left untouched. A
    plain ``str.replace`` is used because the pattern is a single literal
    character, so the regex engine is unnecessary.
    """
    return string.replace("\n", " ")
def _load_clean_chunks():
    """Load every PDF under ``data/`` and split it into newline-free chunks."""
    loader = DirectoryLoader(
        "data/", glob="**/*.pdf", show_progress=True, loader_cls=PyPDFLoader
    )
    documents = loader.load()
    # Chunk size is an important parameter for the quality of the retrieved
    # information, and there are several methods for choosing it. The overlap
    # is the number of characters shared between one chunk and the next.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)
    # Data preparation: flatten line breaks inside every chunk.
    for chunk in chunks:
        chunk.page_content = remove_special_characters(chunk.page_content)
    return chunks


@functools.lru_cache(maxsize=1)
def _get_vector_store():
    """Index the PDF chunks into Chroma once per process and return the store.

    Building the index loads, splits and embeds every PDF, which is far too
    expensive to repeat per query. Each ``Chroma.from_documents`` call also
    writes into the same ``stores/insurance_cosine`` directory without
    explicit ids, so it presumably appends another copy of every chunk.
    TODO confirm against the installed LangChain/Chroma versions.
    NOTE(review): each new process still re-adds the chunks once; consider
    clearing the collection or passing deterministic ids.

    Raises:
        ValueError: if no PDF text could be loaded from ``data/``.
    """
    chunks = _load_clean_chunks()
    if not chunks:
        raise ValueError("No PDF content found under 'data/'; cannot build the vector store.")
    embeddings = SentenceTransformerEmbeddings(model_name=model_name)
    # Load every chunk into the vector database so it can be searched later.
    return Chroma.from_documents(
        chunks,
        embeddings,
        collection_metadata={"hnsw:space": "cosine"},
        persist_directory="stores/insurance_cosine",
    )


def RAG_Langchain(query, k=1):
    """Retrieve the ``k`` indexed PDF chunks most similar to *query*.

    The vector store is built on the first call and reused afterwards.

    Args:
        query: Natural-language question to search for.
        k: Number of chunks to return. Defaults to 1; choosing how many
            context results to use is still an open question.

    Returns:
        A dict of three parallel lists:
        - ``"Score"``: the score returned by Chroma's
          ``similarity_search_with_score``. With ``hnsw:space=cosine`` this is
          presumably a cosine distance, where lower means closer; confirm.
        - ``"Content"``: the chunk texts.
        - ``"Metadata"``: each chunk's metadata.

    Raises:
        ValueError: if no PDF text could be loaded from ``data/``.
    """
    docs = _get_vector_store().similarity_search_with_score(query=query, k=k)
    results = {"Score": [], "Content": [], "Metadata": []}
    for doc, score in docs:
        results["Score"].append(score)
        results["Content"].append(doc.page_content)
        results["Metadata"].append(doc.metadata)
    return results
@functools.lru_cache(maxsize=2)
def _get_question_answerer(model_name):
    """Return a question-answering pipeline for *model_name*, cached per name.

    Loading a model is expensive, so recent pipelines are reused across calls.
    The cache is bounded because the model name is a free-form argument.
    """
    return pipeline("question-answering", model=model_name)


def generateResponseBasedOnContext(model_name, context_string, query, *, context_prompt=""):
    """Answer *query* with a question-answering model run over *context_string*.

    Args:
        model_name: Model passed to ``transformers.pipeline("question-answering")``.
            It must be hashable, because pipelines are cached per model.
        context_string: Retrieved text the answer is taken from.
        query: The user's question.
        context_prompt: Optional text prepended to the context. Defaults to
            empty. A question-answering pipeline picks its answer as a span
            of the context, so instruction text placed there (for example a
            persona prompt) becomes a candidate answer rather than an
            instruction the model follows.

    Returns:
        The ``"answer"`` entry of the pipeline result.
    """
    question_answerer = _get_question_answerer(model_name)
    context = context_prompt + context_string
    logging.getLogger(__name__).debug("context : %s", context)
    result = question_answerer(question=query, context=context)
    return result["answer"]
def gradio_adapted_RAG(model_name, query):
    """Retrieve context for *query* and answer it with the QA model *model_name*.

    Args:
        model_name: Question-answering model identifier. It is coerced to
            ``str`` before use.
        query: The user's question.

    Returns:
        The answer text produced by ``generateResponseBasedOnContext``.
    """
    retrieved_chunks = RAG_Langchain(query)["Content"]
    # Join the chunk texts with spaces. Calling str() on the list would embed
    # Python list syntax (brackets, quotes, escape backslashes) in the context
    # given to the QA model, and those characters could end up in the answer.
    context = " ".join(retrieved_chunks)
    return generateResponseBasedOnContext(str(model_name), context, query)